From 17a1270071c0a376afae33ad53dfe11cb9c6d4d3 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Sun, 8 Jan 2023 06:03:22 -0500 Subject: [PATCH 1/6] Support non-equijoin predicate for EliminateCrossJoin --- benchmarks/expected-plans/q11.txt | 35 +- benchmarks/expected-plans/q17.txt | 4 +- benchmarks/expected-plans/q19.txt | 11 +- benchmarks/expected-plans/q20.txt | 25 +- benchmarks/expected-plans/q22.txt | 21 +- benchmarks/expected-plans/q7.txt | 33 +- datafusion/core/tests/sql/joins.rs | 18 +- datafusion/core/tests/sql/predicates.rs | 11 +- datafusion/core/tests/sql/subqueries.rs | 172 +++--- datafusion/expr/src/utils.rs | 12 + .../optimizer/src/eliminate_cross_join.rs | 569 +++++++++--------- .../src/rewrite_disjunctive_predicate.rs | 49 +- .../optimizer/tests/integration-test.rs | 15 +- 13 files changed, 519 insertions(+), 456 deletions(-) diff --git a/benchmarks/expected-plans/q11.txt b/benchmarks/expected-plans/q11.txt index 7d8e145487ba8..32dc761b3eb8a 100644 --- a/benchmarks/expected-plans/q11.txt +++ b/benchmarks/expected-plans/q11.txt @@ -1,20 +1,19 @@ Sort: value DESC NULLS FIRST Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value - Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__scalar_sq_1.__value AS Decimal128(38, 15)) - CrossJoin: - Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_nationkey] - Filter: nation.n_name = Utf8("GERMANY") - TableScan: nation projection=[n_nationkey, n_name] - SubqueryAlias: __scalar_sq_1 - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value - Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_nationkey] - Filter: nation.n_name = Utf8("GERMANY") - TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file + Inner Join: Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__scalar_sq_1.__value AS Decimal128(38, 15)) + Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_nationkey] + Filter: nation.n_name = Utf8("GERMANY") + TableScan: nation projection=[n_nationkey, n_name] + SubqueryAlias: __scalar_sq_1 + Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value + Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] + Inner Join: supplier.s_nationkey = nation.n_nationkey + Inner Join: partsupp.ps_suppkey = supplier.s_suppkey + TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] + TableScan: supplier projection=[s_suppkey, s_nationkey] + Filter: nation.n_name = Utf8("GERMANY") + TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file diff --git a/benchmarks/expected-plans/q17.txt b/benchmarks/expected-plans/q17.txt index 755311c5ee106..bc495651b82c7 100644 --- a/benchmarks/expected-plans/q17.txt +++ b/benchmarks/expected-plans/q17.txt @@ -1,7 +1,7 @@ Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice)]] - Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < CAST(__scalar_sq_1.__value AS Decimal128(30, 15)) - Inner Join: part.p_partkey = __scalar_sq_1.l_partkey, lineitem.l_partkey = __scalar_sq_1.l_partkey + Projection: lineitem.l_extendedprice + Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < CAST(__scalar_sq_1.__value AS Decimal128(30, 15)) Inner Join: lineitem.l_partkey = part.p_partkey TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") diff --git a/benchmarks/expected-plans/q19.txt b/benchmarks/expected-plans/q19.txt index 969ad02d4c598..74cc961559c70 100644 --- a/benchmarks/expected-plans/q19.txt +++ b/benchmarks/expected-plans/q19.txt @@ -1,8 +1,7 @@ Projection: SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount) AS revenue Aggregate: groupBy=[[]], aggr=[[SUM(CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4))) AS SUM(lineitem.l_extendedprice * Int64(1) - lineitem.l_discount)]] - Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) - Inner Join: lineitem.l_partkey = part.p_partkey - Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") - TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode] - Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) - TableScan: part projection=[p_partkey, p_brand, p_size, p_container] + Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) + Filter: (lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)) AND (lineitem.l_shipmode = Utf8("AIR REG") OR lineitem.l_shipmode = Utf8("AIR")) AND lineitem.l_shipinstruct = Utf8("DELIVER IN PERSON") + TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice, l_discount, l_shipinstruct, l_shipmode] + Filter: (part.p_brand = Utf8("Brand#12") AND part.p_container IN ([Utf8("SM CASE"), Utf8("SM BOX"), Utf8("SM PACK"), Utf8("SM PKG")]) AND part.p_size <= Int32(5) OR part.p_brand = Utf8("Brand#23") AND part.p_container IN ([Utf8("MED BAG"), Utf8("MED BOX"), Utf8("MED PKG"), Utf8("MED PACK")]) AND part.p_size <= Int32(10) OR part.p_brand = Utf8("Brand#34") AND part.p_container IN ([Utf8("LG CASE"), Utf8("LG BOX"), Utf8("LG PACK"), Utf8("LG PKG")]) AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) + TableScan: part projection=[p_partkey, p_brand, p_size, p_container] diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt index b7ecb9a091999..974ee4f9a66c1 100644 --- a/benchmarks/expected-plans/q20.txt +++ b/benchmarks/expected-plans/q20.txt @@ -7,16 +7,15 @@ Sort: supplier.s_name ASC NULLS LAST TableScan: nation projection=[n_nationkey, n_name] SubqueryAlias: __correlated_sq_1 Projection: partsupp.ps_suppkey AS ps_suppkey - Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_1.__value - Inner Join: partsupp.ps_partkey = __scalar_sq_1.l_partkey, partsupp.ps_suppkey = __scalar_sq_1.l_suppkey - LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] - SubqueryAlias: __correlated_sq_2 - Projection: part.p_partkey AS p_partkey - Filter: part.p_name LIKE Utf8("forest%") - TableScan: part projection=[p_partkey, p_name] - SubqueryAlias: __scalar_sq_1 - Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value - Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] - Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") - TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate] \ No newline at end of file + Inner Join: partsupp.ps_partkey = __scalar_sq_1.l_partkey, partsupp.ps_suppkey = __scalar_sq_1.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_1.__value + LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey + TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] + SubqueryAlias: __correlated_sq_2 + Projection: part.p_partkey AS p_partkey + Filter: part.p_name LIKE Utf8("forest%") + TableScan: part projection=[p_partkey, p_name] + SubqueryAlias: __scalar_sq_1 + Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value + Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] + Filter: lineitem.l_shipdate >= Date32("8766") AND lineitem.l_shipdate < Date32("9131") + TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate] \ No newline at end of file diff --git a/benchmarks/expected-plans/q22.txt b/benchmarks/expected-plans/q22.txt index 0fd7a590ac194..18b8a0371e69d 100644 --- a/benchmarks/expected-plans/q22.txt +++ b/benchmarks/expected-plans/q22.txt @@ -3,14 +3,13 @@ Sort: custsale.cntrycode ASC NULLS LAST Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(custsale.c_acctbal)]] SubqueryAlias: custsale Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal - Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_1.__value - CrossJoin: - LeftAnti Join: customer.c_custkey = orders.o_custkey - Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) - TableScan: customer projection=[c_custkey, c_phone, c_acctbal] - TableScan: orders projection=[o_custkey] - SubqueryAlias: __scalar_sq_1 - Projection: AVG(customer.c_acctbal) AS __value - Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] - Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) - TableScan: customer projection=[c_phone, c_acctbal] \ No newline at end of file + Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_1.__value + LeftAnti Join: customer.c_custkey = orders.o_custkey + Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) + TableScan: customer projection=[c_custkey, c_phone, c_acctbal] + TableScan: orders projection=[o_custkey] + SubqueryAlias: __scalar_sq_1 + Projection: AVG(customer.c_acctbal) AS __value + Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] + Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) + TableScan: customer projection=[c_phone, c_acctbal] \ No newline at end of file diff --git a/benchmarks/expected-plans/q7.txt b/benchmarks/expected-plans/q7.txt index bd8c10f8cf88f..e784aadf865a4 100644 --- a/benchmarks/expected-plans/q7.txt +++ b/benchmarks/expected-plans/q7.txt @@ -3,20 +3,19 @@ Sort: shipping.supp_nation ASC NULLS LAST, shipping.cust_nation ASC NULLS LAST, Aggregate: groupBy=[[shipping.supp_nation, shipping.cust_nation, shipping.l_year]], aggr=[[SUM(shipping.volume)]] SubqueryAlias: shipping Projection: n1.n_name AS supp_nation, n2.n_name AS cust_nation, datepart(Utf8("YEAR"), lineitem.l_shipdate) AS l_year, CAST(lineitem.l_extendedprice AS Decimal128(38, 4)) * CAST(Decimal128(Some(100),23,2) - CAST(lineitem.l_discount AS Decimal128(23, 2)) AS Decimal128(38, 4)) AS volume - Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") - Inner Join: customer.c_nationkey = n2.n_nationkey - Inner Join: supplier.s_nationkey = n1.n_nationkey - Inner Join: orders.o_custkey = customer.c_custkey - Inner Join: lineitem.l_orderkey = orders.o_orderkey - Inner Join: supplier.s_suppkey = lineitem.l_suppkey - TableScan: supplier projection=[s_suppkey, s_nationkey] - Filter: lineitem.l_shipdate >= Date32("9131") AND lineitem.l_shipdate <= Date32("9861") - TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate] - TableScan: orders projection=[o_orderkey, o_custkey] - TableScan: customer projection=[c_custkey, c_nationkey] - SubqueryAlias: n1 - Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY") - TableScan: nation projection=[n_nationkey, n_name] - SubqueryAlias: n2 - Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") - TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file + Inner Join: customer.c_nationkey = n2.n_nationkey Filter: n1.n_name = Utf8("FRANCE") AND n2.n_name = Utf8("GERMANY") OR n1.n_name = Utf8("GERMANY") AND n2.n_name = Utf8("FRANCE") + Inner Join: supplier.s_nationkey = n1.n_nationkey + Inner Join: orders.o_custkey = customer.c_custkey + Inner Join: lineitem.l_orderkey = orders.o_orderkey + Inner Join: supplier.s_suppkey = lineitem.l_suppkey + TableScan: supplier projection=[s_suppkey, s_nationkey] + Filter: lineitem.l_shipdate >= Date32("9131") AND lineitem.l_shipdate <= Date32("9861") + TableScan: lineitem projection=[l_orderkey, l_suppkey, l_extendedprice, l_discount, l_shipdate] + TableScan: orders projection=[o_orderkey, o_custkey] + TableScan: customer projection=[c_custkey, c_nationkey] + SubqueryAlias: n1 + Filter: nation.n_name = Utf8("FRANCE") OR nation.n_name = Utf8("GERMANY") + TableScan: nation projection=[n_nationkey, n_name] + SubqueryAlias: n2 + Filter: nation.n_name = Utf8("GERMANY") OR nation.n_name = Utf8("FRANCE") + TableScan: nation projection=[n_nationkey, n_name] \ No newline at end of file diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index 1de20c29cd079..f55779d300432 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -1557,15 +1557,13 @@ async fn reduce_left_join_2() -> Result<()> { // filter expr: `t2.t2_int < 10 or (t1.t1_int > 2 and t2.t2_name != 'w')` // could be write to: `(t1.t1_int > 2 or t2.t2_int < 10) and (t2.t2_name != 'w' or t2.t2_int < 10)` // the right part `(t2.t2_name != 'w' or t2.t2_int < 10)` could be push down left join side and remove in filter. - let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Inner Join: t1.t1_id = t2.t2_id Filter: t2.t2_int < UInt32(10) OR t1.t1_int > UInt32(2) AND t2.t2_name != Utf8(\"w\") [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " Filter: t2.t2_int < UInt32(10) OR t2.t2_name != Utf8(\"w\") [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); @@ -1688,13 +1686,13 @@ async fn reduce_right_join_2() -> Result<()> { let msg = format!("Creating logical plan for '{sql}'"); let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); let plan = dataframe.into_optimized_plan()?; + let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: t1.t1_id, t1.t1_name, t1.t1_int, t2.t2_id, t2.t2_name, t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " Inner Join: t1.t1_id = t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", - " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", - " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " Inner Join: t1.t1_id = t2.t2_id Filter: t1.t1_int != t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_name, t2_int] [t2_id:UInt32;N, t2_name:Utf8;N, t2_int:UInt32;N]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs index 1e8888ce45f9d..4f93e5da8b23a 100644 --- a/datafusion/core/tests/sql/predicates.rs +++ b/datafusion/core/tests/sql/predicates.rs @@ -588,12 +588,11 @@ async fn multiple_or_predicates() -> Result<()> { let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", " Projection: lineitem.l_partkey [l_partkey:Int64]", - " Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " Inner Join: lineitem.l_partkey = part.p_partkey [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", - " Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", - " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Inner Join: lineitem.l_partkey = part.p_partkey Filter: part.p_brand = Utf8(\"Brand#12\") AND lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) AND part.p_size <= Int32(15) [l_partkey:Int64, l_quantity:Decimal128(15, 2), p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " Filter: lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2) [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " TableScan: lineitem projection=[l_partkey, l_quantity], partial_filters=[lineitem.l_quantity >= Decimal128(Some(100),15,2) AND lineitem.l_quantity <= Decimal128(Some(1100),15,2) OR lineitem.l_quantity >= Decimal128(Some(1000),15,2) AND lineitem.l_quantity <= Decimal128(Some(2000),15,2) OR lineitem.l_quantity >= Decimal128(Some(2000),15,2) AND lineitem.l_quantity <= Decimal128(Some(3000),15,2)] [l_partkey:Int64, l_quantity:Decimal128(15, 2)]", + " Filter: (part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)) AND part.p_size >= Int32(1) [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", + " TableScan: part projection=[p_partkey, p_brand, p_size], partial_filters=[part.p_size >= Int32(1), part.p_brand = Utf8(\"Brand#12\") AND part.p_size <= Int32(5) OR part.p_brand = Utf8(\"Brand#23\") AND part.p_size <= Int32(10) OR part.p_brand = Utf8(\"Brand#34\") AND part.p_size <= Int32(15)] [p_partkey:Int64, p_brand:Utf8, p_size:Int32]", ]; let formatted = plan.display_indent_schema().to_string(); let actual: Vec<&str> = formatted.trim().lines().collect(); diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index 3fff5ba3e80c9..1e0b13f32afc2 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -49,22 +49,21 @@ where c_acctbal < ( debug!("input:\n{}", dataframe.logical_plan().display_indent()); let plan = dataframe.into_optimized_plan().unwrap(); - let actual = format!("{}", plan.display_indent()); - let expected = r#"Sort: customer.c_custkey ASC NULLS LAST - Projection: customer.c_custkey - Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value - Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey - TableScan: customer projection=[c_custkey, c_acctbal] - SubqueryAlias: __scalar_sq_1 - Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value - Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] - Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value - Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey - TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice] - SubqueryAlias: __scalar_sq_2 - Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value - Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] - TableScan: lineitem projection=[l_orderkey, l_extendedprice]"#; + let actual = format!("{}", plan.display_indent()); + + let expected = "Sort: customer.c_custkey ASC NULLS LAST\ + \n Projection: customer.c_custkey\ + \n Inner Join: customer.c_custkey = __scalar_sq_1.o_custkey Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __scalar_sq_1.__value\ + \n TableScan: customer projection=[c_custkey, c_acctbal]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\ + \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\ + \n Inner Join: orders.o_orderkey = __scalar_sq_2.l_orderkey Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __scalar_sq_2.__value\ + \n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\ + \n SubqueryAlias: __scalar_sq_2\ + \n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\ + \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\ + \n TableScan: lineitem projection=[l_orderkey, l_extendedprice]"; assert_eq!(actual, expected); Ok(()) @@ -254,19 +253,19 @@ async fn tpch_q17_correlated() -> Result<()> { let dataframe = ctx.sql(sql).await.unwrap(); let plan = dataframe.into_optimized_plan().unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = r#"Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly - Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice)]] - Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < CAST(__scalar_sq_1.__value AS Decimal128(30, 15)) - Inner Join: part.p_partkey = __scalar_sq_1.l_partkey, lineitem.l_partkey = __scalar_sq_1.l_partkey - Inner Join: lineitem.l_partkey = part.p_partkey - TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice] - Filter: part.p_brand = Utf8("Brand#23") AND part.p_container = Utf8("MED BOX") - TableScan: part projection=[p_partkey, p_brand, p_container] - SubqueryAlias: __scalar_sq_1 - Projection: lineitem.l_partkey, Float64(0.2) * CAST(AVG(lineitem.l_quantity) AS Float64) AS __value - Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[AVG(lineitem.l_quantity)]] - TableScan: lineitem projection=[l_partkey, l_quantity]"# - .to_string(); + + let expected = "Projection: CAST(SUM(lineitem.l_extendedprice) AS Float64) / Float64(7) AS avg_yearly\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(lineitem.l_extendedprice)]]\ + \n Projection: lineitem.l_extendedprice\ + \n Inner Join: part.p_partkey = __scalar_sq_1.l_partkey Filter: CAST(lineitem.l_quantity AS Decimal128(30, 15)) < CAST(__scalar_sq_1.__value AS Decimal128(30, 15))\ + \n Inner Join: lineitem.l_partkey = part.p_partkey\ + \n TableScan: lineitem projection=[l_partkey, l_quantity, l_extendedprice]\ + \n Filter: part.p_brand = Utf8(\"Brand#23\") AND part.p_container = Utf8(\"MED BOX\")\ + \n TableScan: part projection=[p_partkey, p_brand, p_container]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Projection: lineitem.l_partkey, Float64(0.2) * CAST(AVG(lineitem.l_quantity) AS Float64) AS __value\ + \n Aggregate: groupBy=[[lineitem.l_partkey]], aggr=[[AVG(lineitem.l_quantity)]]\ + \n TableScan: lineitem projection=[l_partkey, l_quantity]"; assert_eq!(actual, expected); // assert data @@ -309,28 +308,28 @@ order by s_name; let dataframe = ctx.sql(sql).await.unwrap(); let plan = dataframe.into_optimized_plan().unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = r#"Sort: supplier.s_name ASC NULLS LAST - Projection: supplier.s_name, supplier.s_address - LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey - Inner Join: supplier.s_nationkey = nation.n_nationkey - TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] - Filter: nation.n_name = Utf8("CANADA") - TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("CANADA")] - SubqueryAlias: __correlated_sq_1 - Projection: partsupp.ps_suppkey AS ps_suppkey - Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_1.__value - Inner Join: partsupp.ps_partkey = __scalar_sq_1.l_partkey, partsupp.ps_suppkey = __scalar_sq_1.l_suppkey - LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] - SubqueryAlias: __correlated_sq_2 - Projection: part.p_partkey AS p_partkey - Filter: part.p_name LIKE Utf8("forest%") - TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8("forest%")] - SubqueryAlias: __scalar_sq_1 - Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value - Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]] - Filter: lineitem.l_shipdate >= Date32("8766") - TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32("8766")]"#; + let expected = "Sort: supplier.s_name ASC NULLS LAST\ + \n Projection: supplier.s_name, supplier.s_address\ + \n LeftSemi Join: supplier.s_suppkey = __correlated_sq_1.ps_suppkey\ + \n Inner Join: supplier.s_nationkey = nation.n_nationkey\ + \n TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]\ + \n Filter: nation.n_name = Utf8(\"CANADA\")\ + \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"CANADA\")]\ + \n SubqueryAlias: __correlated_sq_1\ + \n Projection: partsupp.ps_suppkey AS ps_suppkey\ + \n Inner Join: partsupp.ps_partkey = __scalar_sq_1.l_partkey, partsupp.ps_suppkey = __scalar_sq_1.l_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __scalar_sq_1.__value\ + \n LeftSemi Join: partsupp.ps_partkey = __correlated_sq_2.p_partkey\ + \n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]\ + \n SubqueryAlias: __correlated_sq_2\ + \n Projection: part.p_partkey AS p_partkey\ + \n Filter: part.p_name LIKE Utf8(\"forest%\")\ + \n TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8(\"forest%\")]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Projection: lineitem.l_partkey, lineitem.l_suppkey, Float64(0.5) * CAST(SUM(lineitem.l_quantity) AS Float64) AS __value\ + \n Aggregate: groupBy=[[lineitem.l_partkey, lineitem.l_suppkey]], aggr=[[SUM(lineitem.l_quantity)]]\ + \n Filter: lineitem.l_shipdate >= Date32(\"8766\")\ + \n TableScan: lineitem projection=[l_partkey, l_suppkey, l_quantity, l_shipdate], partial_filters=[lineitem.l_shipdate >= Date32(\"8766\")]"; + assert_eq!(actual, expected); // assert data @@ -364,22 +363,22 @@ order by cntrycode;"#; let dataframe = ctx.sql(sql).await.unwrap(); let plan = dataframe.into_optimized_plan().unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = r#"Sort: custsale.cntrycode ASC NULLS LAST - Projection: custsale.cntrycode, COUNT(UInt8(1)) AS numcust, SUM(custsale.c_acctbal) AS totacctbal - Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(custsale.c_acctbal)]] - SubqueryAlias: custsale - Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal - Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_1.__value - CrossJoin: - LeftAnti Join: customer.c_custkey = orders.o_custkey - Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) - TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])] - TableScan: orders projection=[o_custkey] - SubqueryAlias: __scalar_sq_1 - Projection: AVG(customer.c_acctbal) AS __value - Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]] - Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]) - TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")]), customer.c_acctbal > Decimal128(Some(0),15,2)]"#; + let expected = "Sort: custsale.cntrycode ASC NULLS LAST\ + \n Projection: custsale.cntrycode, COUNT(UInt8(1)) AS numcust, SUM(custsale.c_acctbal) AS totacctbal\ + \n Aggregate: groupBy=[[custsale.cntrycode]], aggr=[[COUNT(UInt8(1)), SUM(custsale.c_acctbal)]]\ + \n SubqueryAlias: custsale\ + \n Projection: substr(customer.c_phone, Int64(1), Int64(2)) AS cntrycode, customer.c_acctbal\ + \n Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_1.__value\ + \n LeftAnti Join: customer.c_custkey = orders.o_custkey\ + \n Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8(\"13\"), Utf8(\"31\"), Utf8(\"23\"), Utf8(\"29\"), Utf8(\"30\"), Utf8(\"18\"), Utf8(\"17\")])\ + \n TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8(\"13\"), Utf8(\"31\"), Utf8(\"23\"), Utf8(\"29\"), Utf8(\"30\"), Utf8(\"18\"), Utf8(\"17\")])]\ + \n TableScan: orders projection=[o_custkey]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Projection: AVG(customer.c_acctbal) AS __value\ + \n Aggregate: groupBy=[[]], aggr=[[AVG(customer.c_acctbal)]]\ + \n Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8(\"13\"), Utf8(\"31\"), Utf8(\"23\"), Utf8(\"29\"), Utf8(\"30\"), Utf8(\"18\"), Utf8(\"17\")])\ + \n TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2) AS customer.c_acctbal > Decimal128(Some(0),30,15), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8(\"13\"), Utf8(\"31\"), Utf8(\"23\"), Utf8(\"29\"), Utf8(\"30\"), Utf8(\"18\"), Utf8(\"17\")]), customer.c_acctbal > Decimal128(Some(0),15,2)]"; + assert_eq!(expected, actual); // assert data @@ -420,26 +419,25 @@ order by value desc; let dataframe = ctx.sql(sql).await.unwrap(); let plan = dataframe.into_optimized_plan().unwrap(); let actual = format!("{}", plan.display_indent()); - let expected = r#"Sort: value DESC NULLS FIRST - Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value - Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__scalar_sq_1.__value AS Decimal128(38, 15)) - CrossJoin: - Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_nationkey] - Filter: nation.n_name = Utf8("GERMANY") - TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")] - SubqueryAlias: __scalar_sq_1 - Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value - Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]] - Inner Join: supplier.s_nationkey = nation.n_nationkey - Inner Join: partsupp.ps_suppkey = supplier.s_suppkey - TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost] - TableScan: supplier projection=[s_suppkey, s_nationkey] - Filter: nation.n_name = Utf8("GERMANY") - TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8("GERMANY")]"#; + let expected = "Sort: value DESC NULLS FIRST\ + \n Projection: partsupp.ps_partkey, SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS value\ + \n Inner Join: Filter: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Decimal128(38, 15)) > CAST(__scalar_sq_1.__value AS Decimal128(38, 15))\ + \n Aggregate: groupBy=[[partsupp.ps_partkey]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]\ + \n Inner Join: supplier.s_nationkey = nation.n_nationkey\ + \n Inner Join: partsupp.ps_suppkey = supplier.s_suppkey\ + \n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty, ps_supplycost]\ + \n TableScan: supplier projection=[s_suppkey, s_nationkey]\ + \n Filter: nation.n_name = Utf8(\"GERMANY\")\ + \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"GERMANY\")]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Projection: CAST(SUM(partsupp.ps_supplycost * partsupp.ps_availqty) AS Float64) * Float64(0.0001) AS __value\ + \n Aggregate: groupBy=[[]], aggr=[[SUM(CAST(partsupp.ps_supplycost AS Decimal128(26, 2)) * CAST(partsupp.ps_availqty AS Decimal128(26, 2)))]]\ + \n Inner Join: supplier.s_nationkey = nation.n_nationkey\ + \n Inner Join: partsupp.ps_suppkey = supplier.s_suppkey\ + \n TableScan: partsupp projection=[ps_suppkey, ps_availqty, ps_supplycost]\ + \n TableScan: supplier projection=[s_suppkey, s_nationkey]\ + \n Filter: nation.n_name = Utf8(\"GERMANY\")\ + \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"GERMANY\")]"; assert_eq!(actual, expected); // assert data diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 8cf79c612e323..7d5bc9f37c14e 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -957,6 +957,18 @@ fn check_all_column_from_schema(columns: &HashSet, schema: DFSchemaRef) .all(|column| schema.index_of_column(column).is_ok()) } +/// Check whether all columns are from the schemas. +pub fn is_valid_join_predicate(expr: &Expr, schemas: &[DFSchemaRef]) -> Result { + let cols = expr.to_columns()?; + let is_valid = cols.iter().all(|col| { + schemas + .iter() + .any(|schema| schema.index_of_column(col).is_ok()) + }); + + Ok(is_valid) +} + /// Give two sides of the equijoin predicate, return a valid join key pair. /// If there is no valid join key pair, return None. /// diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 458eab95905a1..7b88823a66a0b 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -16,17 +16,17 @@ // under the License. //! Optimizer rule to eliminate cross join to inner join if join predicates are available in filters. -use std::collections::HashSet; use std::sync::Arc; +use crate::utils::{conjunction, split_conjunction_owned}; use crate::{utils, OptimizerConfig, OptimizerRule}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, }; -use datafusion_expr::utils::{can_hash, find_valid_equijoin_key_pair}; -use datafusion_expr::{and, build_join_schema, or, ExprSchemable, Operator}; +use datafusion_expr::utils::{find_valid_equijoin_key_pair, is_valid_join_predicate}; +use datafusion_expr::{build_join_schema, LogicalPlanBuilder, Operator}; #[derive(Default)] pub struct EliminateCrossJoin; @@ -58,20 +58,22 @@ impl OptimizerRule for EliminateCrossJoin { LogicalPlan::Filter(filter) => { let input = filter.input.as_ref().clone(); - let mut possible_join_keys: Vec<(Expr, Expr)> = vec![]; let mut all_inputs: Vec = vec![]; + let mut possible_join_predicates: Vec = + split_conjunction_owned(filter.predicate.clone()); + match &input { LogicalPlan::Join(join) if (join.join_type == JoinType::Inner) => { flatten_join_inputs( &input, - &mut possible_join_keys, + &mut possible_join_predicates, &mut all_inputs, )?; } LogicalPlan::CrossJoin(_) => { flatten_join_inputs( &input, - &mut possible_join_keys, + &mut possible_join_predicates, &mut all_inputs, )?; } @@ -80,19 +82,12 @@ impl OptimizerRule for EliminateCrossJoin { } } - let predicate = &filter.predicate; - // join keys are handled locally - let mut all_join_keys: HashSet<(Expr, Expr)> = HashSet::new(); - - extract_possible_join_keys(predicate, &mut possible_join_keys)?; - let mut left = all_inputs.remove(0); while !all_inputs.is_empty() { left = find_inner_join( &left, &mut all_inputs, - &mut possible_join_keys, - &mut all_join_keys, + &mut possible_join_predicates, )?; } @@ -105,20 +100,15 @@ impl OptimizerRule for EliminateCrossJoin { )); } - // if there are no join keys then do nothing. - if all_join_keys.is_empty() { + // If the rest predicate is not empty. + let other_predicate = conjunction(possible_join_predicates); + if let Some(predicate) = other_predicate { Ok(Some(LogicalPlan::Filter(Filter::try_new( - predicate.clone(), + predicate, Arc::new(left), )?))) } else { - // remove join expressions from filter - match remove_join_expressions(predicate, &all_join_keys)? { - Some(filter_expr) => Ok(Some(LogicalPlan::Filter( - Filter::try_new(filter_expr, Arc::new(left))?, - ))), - _ => Ok(Some(left)), - } + Ok(Some(left)) } } @@ -133,12 +123,24 @@ impl OptimizerRule for EliminateCrossJoin { fn flatten_join_inputs( plan: &LogicalPlan, - possible_join_keys: &mut Vec<(Expr, Expr)>, + possible_join_predicates: &mut Vec, all_inputs: &mut Vec, ) -> Result<()> { let children = match plan { LogicalPlan::Join(join) => { - possible_join_keys.extend(join.on.clone()); + let equijoin_predicates: Vec = join + .on + .iter() + .map(|(l, r)| Expr::eq(l.clone(), r.clone())) + .collect(); + let non_equijoin_predicates = join + .filter + .as_ref() + .map(|expr| split_conjunction_owned(expr.clone())) + .unwrap_or_else(Vec::new); + possible_join_predicates.extend(equijoin_predicates); + possible_join_predicates.extend(non_equijoin_predicates); + let left = &*(join.left); let right = &*(join.right); Ok::, DataFusionError>(vec![left, right]) @@ -159,13 +161,13 @@ fn flatten_join_inputs( match *child { LogicalPlan::Join(left_join) => { if left_join.join_type == JoinType::Inner { - flatten_join_inputs(child, possible_join_keys, all_inputs)?; + flatten_join_inputs(child, possible_join_predicates, all_inputs)?; } else { all_inputs.push((*child).clone()); } } LogicalPlan::CrossJoin(_) => { - flatten_join_inputs(child, possible_join_keys, all_inputs)?; + flatten_join_inputs(child, possible_join_predicates, all_inputs)?; } _ => all_inputs.push((*child).clone()), } @@ -176,159 +178,138 @@ fn flatten_join_inputs( fn find_inner_join( left_input: &LogicalPlan, rights: &mut Vec, - possible_join_keys: &mut Vec<(Expr, Expr)>, - all_join_keys: &mut HashSet<(Expr, Expr)>, + possible_join_predicates: &mut Vec, ) -> Result { + let mut candidate_right_input: Option<(usize, Vec)> = None; for (i, right_input) in rights.iter().enumerate() { - let mut join_keys = vec![]; - - for (l, r) in &mut *possible_join_keys { - let key_pair = find_valid_equijoin_key_pair( - l, - r, - left_input.schema().clone(), - right_input.schema().clone(), - )?; + let predicate_schemas = + [left_input.schema().clone(), right_input.schema().clone()]; + + // equijoin_flag indicct if join_predicates contains equijoin expr. + let (join_predicates, equijoin_flag): (Vec, bool) = + possible_join_predicates.iter().try_fold( + (Vec::::new(), false), + |(mut join_filters, mut equijoin_flag), expr| { + let is_join_filter = + is_valid_join_predicate(expr, &predicate_schemas)?; + if is_join_filter { + join_filters.push(expr.clone()); + } - // Save join keys - match key_pair { - Some((valid_l, valid_r)) => { - if can_hash(&valid_l.get_type(left_input.schema())?) { - join_keys.push((valid_l, valid_r)); + let has_equijoin = match expr { + Expr::BinaryExpr(BinaryExpr { + left, + op: Operator::Eq, + right, + }) => find_valid_equijoin_key_pair( + left, + right, + left_input.schema().clone(), + right_input.schema().clone(), + )? + .is_some(), + _ => false, + }; + + if has_equijoin && !equijoin_flag { + equijoin_flag = !equijoin_flag; } - } - _ => continue, - } - } - if !join_keys.is_empty() { - all_join_keys.extend(join_keys.clone()); - let right_input = rights.remove(i); - let join_schema = Arc::new(build_join_schema( - left_input.schema(), - right_input.schema(), - &JoinType::Inner, - )?); - - return Ok(LogicalPlan::Join(Join { - left: Arc::new(left_input.clone()), - right: Arc::new(right_input), - join_type: JoinType::Inner, - join_constraint: JoinConstraint::On, - on: join_keys, - filter: None, - schema: join_schema, - null_equals_null: false, - })); - } - } - let right = rights.remove(0); - let join_schema = Arc::new(build_join_schema( - left_input.schema(), - right.schema(), - &JoinType::Inner, - )?); - - Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left_input.clone()), - right: Arc::new(right), - schema: join_schema, - })) -} + Result::Ok((join_filters, equijoin_flag)) + }, + )?; -fn intersect( - accum: &mut Vec<(Expr, Expr)>, - vec1: &[(Expr, Expr)], - vec2: &[(Expr, Expr)], -) { - for x1 in vec1.iter() { - for x2 in vec2.iter() { - if x1.0 == x2.0 && x1.1 == x2.1 || x1.1 == x2.0 && x1.0 == x2.1 { - accum.push((x1.0.clone(), x1.1.clone())); - } - } - } -} + // if there are equijoin expr, choose it first. + if !join_predicates.is_empty() && equijoin_flag { + possible_join_predicates + .retain(|expr| join_predicates.iter().all(|filter| expr != filter)); -/// Extract join keys from a WHERE clause -fn extract_possible_join_keys(expr: &Expr, accum: &mut Vec<(Expr, Expr)>) -> Result<()> { - if let Expr::BinaryExpr(BinaryExpr { left, op, right }) = expr { - match op { - Operator::Eq => { - // Ensure that we don't add the same Join keys multiple times - if !(accum.contains(&(*left.clone(), *right.clone())) - || accum.contains(&(*right.clone(), *left.clone()))) - { - accum.push((*left.clone(), *right.clone())); - } - } - Operator::And => { - extract_possible_join_keys(left, accum)?; - extract_possible_join_keys(right, accum)? - } - // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Operator::Or => { - let mut left_join_keys = vec![]; - let mut right_join_keys = vec![]; + let join_filter = conjunction(join_predicates); + let right_input = rights.remove(i); - extract_possible_join_keys(left, &mut left_join_keys)?; - extract_possible_join_keys(right, &mut right_join_keys)?; + return LogicalPlanBuilder::from(left_input.clone()) + .join( + right_input, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + join_filter, + )? + .build(); + // let join_schema = Arc::new(build_join_schema( + // left_input.schema(), + // right_input.schema(), + // &JoinType::Inner, + // )?); + + // return Ok(LogicalPlan::Join(Join { + // left: Arc::new(left_input.clone()), + // right: Arc::new(right_input), + // join_type: JoinType::Inner, + // join_constraint: JoinConstraint::On, + // on: Vec::<(Expr, Expr)>::new(), + // filter: join_filter, + // schema: join_schema, + // null_equals_null: false, + // })); + } - intersect(accum, &left_join_keys, &right_join_keys) - } - _ => (), - }; + // If there is non equijoin predicate for the left and right, + // choose the first plan whose join predicates is not empty. + if !join_predicates.is_empty() && candidate_right_input.is_none() { + candidate_right_input = Some((i, join_predicates)); + } } - Ok(()) -} -/// Remove join expressions from a filter expression -/// Returns Some() when there are few remaining predicates in filter_expr -/// Returns None otherwise -fn remove_join_expressions( - expr: &Expr, - join_keys: &HashSet<(Expr, Expr)>, -) -> Result> { - match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { - Operator::Eq => { - if join_keys.contains(&(*left.clone(), *right.clone())) - || join_keys.contains(&(*right.clone(), *left.clone())) - { - Ok(None) - } else { - Ok(Some(expr.clone())) - } - } - Operator::And => { - let l = remove_join_expressions(left, join_keys)?; - let r = remove_join_expressions(right, join_keys)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(and(ll, rr))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), - } - } - // Fix for issue#78 join predicates from inside of OR expr also pulled up properly. - Operator::Or => { - let l = remove_join_expressions(left, join_keys)?; - let r = remove_join_expressions(right, join_keys)?; - match (l, r) { - (Some(ll), Some(rr)) => Ok(Some(or(ll, rr))), - (Some(ll), _) => Ok(Some(ll)), - (_, Some(rr)) => Ok(Some(rr)), - _ => Ok(None), - } - } - _ => Ok(Some(expr.clone())), - }, - _ => Ok(Some(expr.clone())), + if let Some((right_idx, join_filters)) = candidate_right_input { + let right_input = rights.remove(right_idx); + possible_join_predicates + .retain(|expr| join_filters.iter().all(|filter| expr != filter)); + let join_filter = conjunction(join_filters); + + LogicalPlanBuilder::from(left_input.clone()) + .join( + right_input, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + join_filter, + )? + .build() + // let join_schema = Arc::new(build_join_schema( + // left_input.schema(), + // right_input.schema(), + // &JoinType::Inner, + // )?); + + // Ok(LogicalPlan::Join(Join { + // left: Arc::new(left_input.clone()), + // right: Arc::new(right_input), + // join_type: JoinType::Inner, + // join_constraint: JoinConstraint::On, + // on: Vec::<(Expr, Expr)>::new(), + // filter: join_filter, + // schema: join_schema, + // null_equals_null: false, + // })) + } else { + // if no join predicate exists, choose first with cross join + let right = rights.remove(0); + let join_schema = Arc::new(build_join_schema( + left_input.schema(), + right.schema(), + &JoinType::Inner, + )?); + + Ok(LogicalPlan::CrossJoin(CrossJoin { + left: Arc::new(left_input.clone()), + right: Arc::new(right), + schema: join_schema, + })) } } #[cfg(test)] mod tests { + use datafusion_common::Column; use datafusion_expr::{ binary_expr, col, lit, logical_plan::builder::LogicalPlanBuilder, @@ -373,10 +354,9 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -400,11 +380,16 @@ mod tests { ))? .build()?; + // let expected = vec![ + // "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + // " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + // " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + // " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + // ]; let expected = vec![ - "Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a = t2.a OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -428,10 +413,9 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(20) AND t1.a = t2.a AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -459,10 +443,9 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -489,11 +472,11 @@ mod tests { .build()?; let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.b = t2.b AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; + assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -515,11 +498,11 @@ mod tests { .build()?; let expected = vec![ - "Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; + assert_optimized_plan_eq(&plan, expected); Ok(()) @@ -558,13 +541,12 @@ mod tests { .build()?; let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a = t1.a AND t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -631,17 +613,18 @@ mod tests { ))? .build()?; + // the filter of (t1 join t2): t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688). + // the filter of (t3 join t4): t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b + // the outer filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688) + // they are all `OR` expressions, so can not be split. let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -706,17 +689,17 @@ mod tests { ))? .build()?; + // the filter of (t1 join t2): t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688). + // the filter of (t3 join t4): t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b. + // the outer filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) let expected = vec![ - "Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -781,17 +764,17 @@ mod tests { ))? .build()?; + // the filter of (t1 join t2): t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688). + // the filter of (t3 join t4): t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b. + // the outer filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a OR t4.c = UInt32(688) let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t1.a = t2.a AND t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -860,17 +843,17 @@ mod tests { ))? .build()?; + // the filter of (t1 join t2): t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688). + // the filter of (t3 join t4): t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b. + // the outer filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688) let expected = vec![ - "Filter: t4.c < UInt32(15) OR t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " Filter: t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t1.a = t2.a OR t2.c < UInt32(15) OR t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -944,16 +927,17 @@ mod tests { ))? .build()?; + // the filter of (t1 join t2): (t1.a = t2.a OR t2.c < UInt32(15)) AND t1.a = t2.a AND t2.c = UInt32(688). + // the filter of (t3 join (t1 + t2)): empty. + // the outer filter: (t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688)) AND (t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b) let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Filter: t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: (t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688)) AND (t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: (t1.a = t2.a OR t2.c < UInt32(15)) AND t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -1032,15 +1016,17 @@ mod tests { ))? .build()?; + // the filter of (t1 join t2): (t1.a = t2.a OR t2.c < UInt32(15)) AND t1.a = t2.a AND t2.c = UInt32(688). + // the filter of (t3 join (t1 + t2)): empty. + // the outer filter: (t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688)) AND (t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b) let expected = vec![ - "Filter: (t4.c < UInt32(15) OR t4.c = UInt32(688)) AND (t4.c < UInt32(15) OR t3.c = UInt32(688) OR t3.b = t4.b) AND t2.c < UInt32(15) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a = t4.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: (t3.a = t1.a AND t4.c < UInt32(15) OR t3.a = t1.a AND t4.c = UInt32(688)) AND (t3.a = t4.a AND t4.c < UInt32(15) OR t3.a = t4.a AND t3.c = UInt32(688) OR t3.a = t4.a AND t3.b = t4.b) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: (t1.a = t2.a OR t2.c < UInt32(15)) AND t1.a = t2.a AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t4 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -1064,10 +1050,10 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]"]; + "Inner Join: Filter: t1.a + UInt32(100) = t2.a * UInt32(2) AND t2.c < UInt32(20) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; assert_optimized_plan_eq(&plan, expected); @@ -1091,10 +1077,9 @@ mod tests { .build()?; let expected = vec![ - "Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " CrossJoin: [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a + UInt32(100) = t2.a * UInt32(2) OR t2.b = t1.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -1119,10 +1104,9 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(20) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a + UInt32(100) = t2.a * UInt32(2) AND t2.c < UInt32(20) AND t1.a + UInt32(100) = t2.a * UInt32(2) AND t2.c = UInt32(10) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -1147,10 +1131,9 @@ mod tests { .build()?; let expected = vec![ - "Filter: t2.c < UInt32(15) OR t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Inner Join: Filter: t1.a + UInt32(100) = t2.a * UInt32(2) AND t2.c < UInt32(15) OR t1.a + UInt32(100) = t2.a * UInt32(2) AND t2.c = UInt32(688) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); @@ -1184,13 +1167,45 @@ mod tests { .build()?; let expected = vec![ - "Filter: t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " Inner Join: t1.a * UInt32(2) = t3.a + UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", - " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + "Projection: t1.a, t1.b, t1.c, t2.a, t2.b, t2.c, t3.a, t3.b, t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a + UInt32(100) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t3.a + UInt32(100) = t1.a * UInt32(2) AND t3.c < UInt32(15) AND t3.b < UInt32(15) [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + ]; + + assert_optimized_plan_eq(&plan, expected); + + Ok(()) + } + + #[test] + fn inner_join_with_join_filter() -> Result<()> { + let t1 = test_table_scan_with_name("t1")?; + let t2 = test_table_scan_with_name("t2")?; + let t3 = test_table_scan_with_name("t3")?; + + let t1_join_t2 = LogicalPlanBuilder::from(t1) + .join( + t2, + JoinType::Inner, + (Vec::::new(), Vec::::new()), + Some(col("t1.b").gt(col("t2.b"))), + )? + .build()?; + + let plan = LogicalPlanBuilder::from(t1_join_t2) + .cross_join(t3)? + .filter((col("t1.a").gt(col("t2.a"))).and(col("t1.c").lt(col("t3.a"))))? + .build()?; + + let expected = vec![ + "Inner Join: Filter: t1.c < t3.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " Inner Join: Filter: t1.a > t2.a AND t1.b > t2.b [a:UInt32, b:UInt32, c:UInt32, a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]", + " TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]", ]; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 57513fa4fff41..b571512468679 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -152,13 +152,48 @@ impl OptimizerRule for RewriteDisjunctivePredicate { } } -#[derive(Clone, PartialEq, Debug)] +#[derive(Clone, Debug)] enum Predicate { And { args: Vec }, Or { args: Vec }, Other { expr: Box }, } +impl PartialEq for Predicate { + fn eq(&self, other: &Predicate) -> bool { + match (self, other) { + // For BinaryExpr, A = B should be the same as B = A + // when extract A = B in expr like (A = B and C > D) OR (B = A and C = E). + (Predicate::Other { expr: l }, Predicate::Other { expr: r }) => { + match (l.as_ref(), r.as_ref()) { + ( + Expr::BinaryExpr(BinaryExpr { + left: ll, + op: l_op, + right: lr, + }), + Expr::BinaryExpr(BinaryExpr { + left: rl, + op: r_op, + right: rr, + }), + ) => { + l_op == r_op && ((ll == rl && lr == rr) || (ll == rr && lr == rl)) + } + _ => false, + } + } + (Predicate::And { args: l_args }, Predicate::And { args: r_args }) => { + l_args == r_args + } + (Predicate::Or { args: l_args }, Predicate::Or { args: r_args }) => { + l_args == r_args + } + _ => false, + } + } +} + fn predicate(expr: &Expr) -> Result { match expr { Expr::BinaryExpr(BinaryExpr { left, op, right }) => match op { @@ -424,4 +459,16 @@ mod tests { assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr))); Ok(()) } + + #[test] + fn test_ignore_binay_order() { + let equi_expr1 = col("t1.a").eq(col("t2.b")); + let equi_expr2 = col("t1.b").eq(col("t2.a")); + + let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1)))); + let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2)))); + let expr = or(and(equi_expr1, gt_expr), and(equi_expr2, lt_expr)); + + println!("{:?}", rewrite_predicate(predicate(&expr).unwrap())); + } } diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index f901e33d41d67..70f66121c7ea1 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -65,14 +65,13 @@ fn subquery_filter_with_cast() -> Result<()> { )"; let plan = test_sql(sql)?; let expected = "Projection: test.col_int32\ - \n Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.__value\ - \n CrossJoin:\ - \n TableScan: test projection=[col_int32]\ - \n SubqueryAlias: __scalar_sq_1\ - \n Projection: AVG(test.col_int32) AS __value\ - \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ - \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ - \n TableScan: test projection=[col_int32, col_utf8]"; + \n Inner Join: Filter: CAST(test.col_int32 AS Float64) > __scalar_sq_1.__value\ + \n TableScan: test projection=[col_int32]\ + \n SubqueryAlias: __scalar_sq_1\ + \n Projection: AVG(test.col_int32) AS __value\ + \n Aggregate: groupBy=[[]], aggr=[[AVG(test.col_int32)]]\ + \n Filter: test.col_utf8 >= Utf8(\"2002-05-08\") AND test.col_utf8 <= Utf8(\"2002-05-13\")\ + \n TableScan: test projection=[col_int32, col_utf8]"; assert_eq!(expected, format!("{plan:?}")); Ok(()) } From a95ee418e01da53b5c0b910d3f2fe77195bf9b73 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 10 Jan 2023 07:29:27 -0500 Subject: [PATCH 2/6] add tests --- datafusion/core/tests/sql/joins.rs | 99 +++++++++++++++++++ datafusion/core/tests/sql/mod.rs | 34 +++++++ .../core/tests/sqllogictests/src/join.rs | 0 .../tests/sqllogictests/test_files/join.slt | 44 +++++++++ 4 files changed, 177 insertions(+) create mode 100644 datafusion/core/tests/sqllogictests/src/join.rs create mode 100644 datafusion/core/tests/sqllogictests/test_files/join.slt diff --git a/datafusion/core/tests/sql/joins.rs b/datafusion/core/tests/sql/joins.rs index f55779d300432..6bdcdfe73743c 100644 --- a/datafusion/core/tests/sql/joins.rs +++ b/datafusion/core/tests/sql/joins.rs @@ -2808,3 +2808,102 @@ async fn type_coercion_join_with_filter_and_equi_expr() -> Result<()> { Ok(()) } + +#[tokio::test] +async fn reduce_two_table_with_join_filter() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \ + from t1 \ + inner join t2 on t1.t1_id * 4 < t2.t2_id \ + where t1.t1_int > t2.t2_id"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: Filter: t1.t1_int > t2.t2_id AND CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn reduce_three_table_with_join_filter() -> Result<()> { + let ctx = create_muti_table_join_context()?; + + let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \ + from t1 \ + inner join t2 on t1.t1_id * 4 < t2.t2_id \ + cross join t3 \ + where t1.t1_int > t3.t3_id"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: Filter: t1.t1_int > t3.t3_id [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t3_id:UInt32;N]", + " Inner Join: Filter: CAST(t1.t1_id AS Int64) * Int64(4) < CAST(t2.t2_id AS Int64) [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 projection=[t2_id] [t2_id:UInt32;N]", + " TableScan: t3 projection=[t3_id] [t3_id:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} + +#[tokio::test] +async fn reduce_both_with_eq_and_non_eq_expr() -> Result<()> { + let ctx = create_join_context("t1_id", "t2_id", false)?; + + let sql = "select t1.t1_id, t1.t1_name, t2.t2_id \ + from t1 \ + inner join t2 on t1.t1_id * 4 = t2.t2_id and t1.t1_int > t2.t2_int \ + where t1.t1_int + 10 < t2.t2_id"; + + // assert logical plan + let msg = format!("Creating logical plan for '{sql}'"); + let dataframe = ctx.sql(&("explain ".to_owned() + sql)).await.expect(&msg); + let plan = dataframe.into_optimized_plan().unwrap(); + + let expected = vec![ + "Explain [plan_type:Utf8, plan:Utf8]", + " Projection: t1.t1_id, t1.t1_name, t2.t2_id [t1_id:UInt32;N, t1_name:Utf8;N, t2_id:UInt32;N]", + " Inner Join: CAST(t1.t1_id AS Int64) * Int64(4) = CAST(t2.t2_id AS Int64) Filter: CAST(t1.t1_int AS Int64) + Int64(10) < CAST(t2.t2_id AS Int64) AND t1.t1_int > t2.t2_int [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N, t2_id:UInt32;N, t2_int:UInt32;N]", + " TableScan: t1 projection=[t1_id, t1_name, t1_int] [t1_id:UInt32;N, t1_name:Utf8;N, t1_int:UInt32;N]", + " TableScan: t2 projection=[t2_id, t2_int] [t2_id:UInt32;N, t2_int:UInt32;N]", + ]; + + let formatted = plan.display_indent_schema().to_string(); + let actual: Vec<&str> = formatted.trim().lines().collect(); + assert_eq!( + expected, actual, + "\n\nexpected:\n\n{expected:#?}\nactual:\n\n{actual:#?}\n\n" + ); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 1c50c8ad0f5bb..28b056259f3d5 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -692,6 +692,40 @@ fn create_sort_merge_join_datatype_context() -> Result { Ok(ctx) } +fn create_muti_table_join_context() -> Result { + let ctx = SessionContext::with_config( + SessionConfig::new() + .with_repartition_joins(false) + .with_target_partitions(2) + .with_batch_size(4096), + ); + + let t1_schema = Arc::new(Schema::new(vec![ + Field::new("t1_id", DataType::UInt32, true), + Field::new("t1_name", DataType::Utf8, true), + Field::new("t1_int", DataType::UInt32, true), + ])); + let t1_data = RecordBatch::new_empty(t1_schema); + let t2_schema = Arc::new(Schema::new(vec![ + Field::new("t2_id", DataType::UInt32, true), + Field::new("t2_name", DataType::Utf8, true), + Field::new("t2_int", DataType::UInt32, true), + ])); + let t2_data = RecordBatch::new_empty(t2_schema); + let t3_schema = Arc::new(Schema::new(vec![ + Field::new("t3_id", DataType::UInt32, true), + Field::new("t3_name", DataType::Utf8, true), + Field::new("t3_int", DataType::UInt32, true), + ])); + let t3_data = RecordBatch::new_empty(t3_schema); + + ctx.register_batch("t1", t1_data)?; + ctx.register_batch("t2", t2_data)?; + ctx.register_batch("t3", t3_data)?; + + Ok(ctx) +} + fn get_tpch_table_schema(table: &str) -> Schema { match table { "customer" => Schema::new(vec![ diff --git a/datafusion/core/tests/sqllogictests/src/join.rs b/datafusion/core/tests/sqllogictests/src/join.rs new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/sqllogictests/test_files/join.slt b/datafusion/core/tests/sqllogictests/test_files/join.slt new file mode 100644 index 0000000000000..4366d99a9e673 --- /dev/null +++ b/datafusion/core/tests/sqllogictests/test_files/join.slt @@ -0,0 +1,44 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +########## +## JOIN Tests +########## + +statement ok +CREATE TABLE students(name TEXT, mark INT) AS VALUES +('Stuart', 28), +('Amina', 89), +('Christen', 50), +('Salma', 77), +('Samantha', 21); + +statement ok +CREATE TABLE grades(grade INT, min INT, max INT) AS VALUES +(1, 0, 14), +(2, 15, 35), +(3, 36, 55), +(4, 56, 79), +(5, 80, 100); + +# Regression test: https://github.com/apache/arrow-datafusion/issues/4844 +query I +SELECT s.*, g.grade FROM students s join grades g on s.mark between g.min and g.max WHERE grade > 2 ORDER BY s.mark DESC +---- +Amina 89 5 +Salma 77 4 +Christen 50 3 From e4c762145721bf986320187f4ff786cec237b655 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 10 Jan 2023 07:32:01 -0500 Subject: [PATCH 3/6] remove unused file --- datafusion/core/tests/sqllogictests/src/join.rs | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 datafusion/core/tests/sqllogictests/src/join.rs diff --git a/datafusion/core/tests/sqllogictests/src/join.rs b/datafusion/core/tests/sqllogictests/src/join.rs deleted file mode 100644 index e69de29bb2d1d..0000000000000 From 2ddf93134210803f74d8fef94b399a123ec659b0 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 10 Jan 2023 08:44:06 -0500 Subject: [PATCH 4/6] fix carg fmt --- datafusion/core/src/physical_plan/repartition.rs | 13 +++++++++---- datafusion/core/tests/sql/subqueries.rs | 2 +- .../optimizer/src/rewrite_disjunctive_predicate.rs | 2 +- 3 files changed, 11 insertions(+), 6 deletions(-) diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 451b0fba4b130..ccfb2a9e1ade4 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -399,10 +399,15 @@ impl ExecutionPlan for RepartitionExec { // now return stream for the specified *output* partition which will // read from the channel - let (_tx, rx, reservation) = state - .channels - .remove(&partition) - .expect("partition not used yet"); + let (_tx, rx, reservation) = state.channels.remove(&partition).expect( + format!( + "partition not used yet, partition:{:?}, channels:{:?}", + partition, + state.channels.keys() + ) + .as_str(), + ); + // .expect("partition not used yet"); Ok(Box::pin(RepartitionStream { num_input_partitions, num_input_partitions_processed: 0, diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index a32dec9a72d41..879992abf0f25 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -49,7 +49,7 @@ where c_acctbal < ( debug!("input:\n{}", dataframe.logical_plan().display_indent()); let plan = dataframe.into_optimized_plan().unwrap(); - let actual = format!("{}", plan.display_indent()); + let actual = format!("{}", plan.display_indent()); let expected = "Sort: customer.c_custkey ASC NULLS LAST\ \n Projection: customer.c_custkey\ diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index b571512468679..3e732025a6f8d 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -162,7 +162,7 @@ enum Predicate { impl PartialEq for Predicate { fn eq(&self, other: &Predicate) -> bool { match (self, other) { - // For BinaryExpr, A = B should be the same as B = A + // For BinaryExpr, A = B should be the same as B = A // when extract A = B in expr like (A = B and C > D) OR (B = A and C = E). (Predicate::Other { expr: l }, Predicate::Other { expr: r }) => { match (l.as_ref(), r.as_ref()) { From 5ed47b87bbc97d01b6b5999dca51ad4db15ec55f Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 10 Jan 2023 09:12:15 -0500 Subject: [PATCH 5/6] fix cargo clippy --- datafusion/core/src/physical_plan/repartition.rs | 5 +++++ datafusion/optimizer/src/eliminate_cross_join.rs | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index ccfb2a9e1ade4..004ca3825aea0 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -399,6 +399,11 @@ impl ExecutionPlan for RepartitionExec { // now return stream for the specified *output* partition which will // read from the channel + println!( + "partition: {:?}, channels:{:?}", + partition, + state.channels.keys() + ); let (_tx, rx, reservation) = state.channels.remove(&partition).expect( format!( "partition not used yet, partition:{:?}, channels:{:?}", diff --git a/datafusion/optimizer/src/eliminate_cross_join.rs b/datafusion/optimizer/src/eliminate_cross_join.rs index 7b88823a66a0b..b45fa8fe71d37 100644 --- a/datafusion/optimizer/src/eliminate_cross_join.rs +++ b/datafusion/optimizer/src/eliminate_cross_join.rs @@ -23,7 +23,7 @@ use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DataFusionError, Result}; use datafusion_expr::expr::{BinaryExpr, Expr}; use datafusion_expr::logical_plan::{ - CrossJoin, Filter, Join, JoinConstraint, JoinType, LogicalPlan, Projection, + CrossJoin, Filter, JoinType, LogicalPlan, Projection, }; use datafusion_expr::utils::{find_valid_equijoin_key_pair, is_valid_join_predicate}; use datafusion_expr::{build_join_schema, LogicalPlanBuilder, Operator}; From a7b0e2aaa3ae1b6f0d066c9243899974a223c461 Mon Sep 17 00:00:00 2001 From: ygf11 Date: Tue, 10 Jan 2023 09:35:59 -0500 Subject: [PATCH 6/6] remove comment --- .../core/src/physical_plan/repartition.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/physical_plan/repartition.rs b/datafusion/core/src/physical_plan/repartition.rs index 004ca3825aea0..451b0fba4b130 100644 --- a/datafusion/core/src/physical_plan/repartition.rs +++ b/datafusion/core/src/physical_plan/repartition.rs @@ -399,20 +399,10 @@ impl ExecutionPlan for RepartitionExec { // now return stream for the specified *output* partition which will // read from the channel - println!( - "partition: {:?}, channels:{:?}", - partition, - state.channels.keys() - ); - let (_tx, rx, reservation) = state.channels.remove(&partition).expect( - format!( - "partition not used yet, partition:{:?}, channels:{:?}", - partition, - state.channels.keys() - ) - .as_str(), - ); - // .expect("partition not used yet"); + let (_tx, rx, reservation) = state + .channels + .remove(&partition) + .expect("partition not used yet"); Ok(Box::pin(RepartitionStream { num_input_partitions, num_input_partitions_processed: 0,