Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -355,26 +355,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
bottomUp(new EliminateJoinByFK()),
topDown(new EliminateJoinByUnique())
),

// this rule should be after topic "Column pruning and infer predicate"
topic("Join pull up",
topDown(
new EliminateFilter(),
new PushDownFilterThroughProject(),
new MergeProjects()
),
topDown(
new PullUpJoinFromUnionAll()
),
custom(RuleType.COLUMN_PRUNING, ColumnPruning::new),
bottomUp(RuleSet.PUSH_DOWN_FILTERS),
custom(RuleType.ELIMINATE_UNNECESSARY_PROJECT, EliminateUnnecessaryProject::new)
),

// this rule should be invoked after topic "Join pull up"
topic("eliminate Aggregate according to fd items",
topDown(new EliminateGroupByKey()),
topDown(new PushDownAggThroughJoinOnPkFk())
topDown(new PushDownAggThroughJoinOnPkFk()),
topDown(new PullUpJoinFromUnionAll())
),

topic("Limit optimization",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ public enum RuleType {

// split limit
SPLIT_LIMIT(RuleTypeClass.REWRITE),
PULL_UP_JOIN_FROM_UNIONALL(RuleTypeClass.REWRITE),
PULL_UP_JOIN_FROM_UNION_ALL(RuleTypeClass.REWRITE),
// limit push down
PUSH_LIMIT_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import org.junit.jupiter.api.Test;

class PullUpJoinFromUnionTest extends TestWithFeService implements MemoPatternMatchSupported {
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
connectContext.setDatabase("default_cluster:test");
createTables(
"CREATE TABLE IF NOT EXISTS t1 (\n"
+ " id int not null,\n"
+ " name char\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ "PROPERTIES (\"replication_num\" = \"1\")\n",
"CREATE TABLE IF NOT EXISTS t2 (\n"
+ " id int not null,\n"
+ " name char\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ "PROPERTIES (\"replication_num\" = \"1\")\n",
"CREATE TABLE IF NOT EXISTS t3 (\n"
+ " id int,\n"
+ " name char\n"
+ ")\n"
+ "DUPLICATE KEY(id)\n"
+ "DISTRIBUTED BY HASH(id) BUCKETS 10\n"
+ "PROPERTIES (\"replication_num\" = \"1\")\n"
);
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
}

@Test
void testSimple() {
String sql = "select * from t1 join t2 on t1.id = t2.id "
+ "union all "
+ "select * from t1 join t3 on t1.id = t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));
}

@Test
void testProject() {
String sql = "select t2.id from t1 join t2 on t1.id = t2.id "
+ "union all "
+ "select t3.id from t1 join t3 on t1.id = t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));

sql = "select t2.id, t1.name from t1 join t2 on t1.id = t2.id "
+ "union all "
+ "select t3.id, t1.name from t1 join t3 on t1.id = t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));
}

@Test
void testConstant() {
String sql = "select t2.id, t1.name, 1 as id1 from t1 join t2 on t1.id = t2.id "
+ "union all "
+ "select t3.id, t1.name, 2 as id2 from t1 join t3 on t1.id = t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));
}

@Test
void testComplexProject() {
String sql = "select t2.id + 1, t1.name + 1, 1 as id1 from t1 join t2 on t1.id = t2.id "
+ "union all "
+ "select t3.id + 1, t1.name + 1, 2 as id2 from t1 join t3 on t1.id = t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalUnion(), any()));
}

@Test
void testMissJoinSlot() {
String sql = "select t1.name + 1, 1 as id1 from t1 join t2 on t1.id = t2.id "
+ "union all "
+ "select t1.name + 1, 2 as id2 from t1 join t3 on t1.id = t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalUnion(), any()));
}

@Test
void testFilter() {
String sql = "select * from t1 join t2 on t1.id = t2.id where t1.name = '' "
+ "union all "
+ "select * from t1 join t3 on t1.id = t3.id where t1.name = '' ;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));

sql = "select t2.id from t1 join t2 on t1.id = t2.id where t1.name = '' "
+ "union all "
+ "select t3.id from t1 join t3 on t1.id = t3.id where t1.name = '' ;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));
}

@Test
void testMultipleJoinConditions() {
String sql = "select * from t1 join t2 on t1.id = t2.id and t1.name = t2.name "
+ "union all "
+ "select * from t1 join t3 on t1.id = t3.id and t1.name = t3.name;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));
}

@Test
void testNonEqualityJoinConditions() {
String sql = "select * from t1 join t2 on t1.id < t2.id "
+ "union all "
+ "select * from t1 join t3 on t1.id < t3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.nonMatch(logicalJoin(logicalProject(logicalUnion()), any()));
}

@Test
void testSubqueries() {
String sql = "select * from t1 join (select * from t2 where t2.id > 10) s2 on t1.id = s2.id "
+ "union all "
+ "select * from t1 join (select * from t3 where t3.id > 10) s3 on t1.id = s3.id;";
PlanChecker.from(connectContext)
.analyze(sql)
.rewrite()
.matches(logicalJoin(logicalProject(logicalUnion()), any()));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------hashJoin[INNER_JOIN shuffle] hashCondition=((PULL_UP_UNIFIED_OUTPUT_ALIAS = customer.c_customer_sk)) otherCondition=() build RFs:RF2 c_customer_sk->[ss_customer_sk,ws_bill_customer_sk]
------hashJoin[INNER_JOIN shuffle] hashCondition=((ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF2 c_customer_sk->[ss_customer_sk,ws_bill_customer_sk]
--------PhysicalProject
----------PhysicalUnion
------------PhysicalProject
Expand Down
62 changes: 26 additions & 36 deletions regression-test/data/nereids_hint_tpcds_p0/shape/query14.out
Original file line number Diff line number Diff line change
Expand Up @@ -55,31 +55,21 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------PhysicalDistribute[DistributionSpecGather]
----------hashAgg[LOCAL]
------------PhysicalProject
--------------PhysicalUnion
----------------PhysicalDistribute[DistributionSpecExecutionAny]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF9 d_date_sk->[ss_sold_date_sk]
--------------hashJoin[INNER_JOIN broadcast] hashCondition=((ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF9 d_date_sk->[cs_sold_date_sk,ss_sold_date_sk,ws_sold_date_sk]
----------------PhysicalProject
------------------PhysicalUnion
--------------------PhysicalDistribute[DistributionSpecExecutionAny]
----------------------PhysicalProject
------------------------PhysicalOlapScan[store_sales] apply RFs: RF9
--------------------PhysicalDistribute[DistributionSpecExecutionAny]
----------------------PhysicalProject
------------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------------PhysicalOlapScan[date_dim]
----------------PhysicalDistribute[DistributionSpecExecutionAny]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF10 d_date_sk->[cs_sold_date_sk]
----------------------PhysicalProject
------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF10
----------------------PhysicalProject
------------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------------PhysicalOlapScan[date_dim]
----------------PhysicalDistribute[DistributionSpecExecutionAny]
------------------PhysicalProject
--------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF11 d_date_sk->[ws_sold_date_sk]
------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF9
--------------------PhysicalDistribute[DistributionSpecExecutionAny]
----------------------PhysicalProject
------------------------PhysicalOlapScan[web_sales] apply RFs: RF11
----------------------PhysicalProject
------------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------------PhysicalOlapScan[date_dim]
------------------------PhysicalOlapScan[web_sales] apply RFs: RF9
----------------PhysicalProject
------------------filter((date_dim.d_year <= 2001) and (date_dim.d_year >= 1999))
--------------------PhysicalOlapScan[date_dim]
----PhysicalResultSink
------PhysicalTopN[MERGE_SORT]
--------PhysicalDistribute[DistributionSpecGather]
Expand All @@ -97,16 +87,16 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------PhysicalDistribute[DistributionSpecHash]
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF14 i_item_sk->[ss_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF13 ss_item_sk->[ss_item_sk]
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF12 i_item_sk->[ss_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((store_sales.ss_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF11 ss_item_sk->[ss_item_sk]
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF12 d_date_sk->[ss_sold_date_sk]
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF10 d_date_sk->[ss_sold_date_sk]
----------------------------------------------PhysicalProject
------------------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF12 RF13 RF14
------------------------------------------------PhysicalOlapScan[store_sales] apply RFs: RF10 RF11 RF12
----------------------------------------------PhysicalProject
------------------------------------------------filter((date_dim.d_moy = 11) and (date_dim.d_year = 2001))
--------------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF14
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF12
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
Expand All @@ -120,16 +110,16 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------PhysicalDistribute[DistributionSpecHash]
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF17 i_item_sk->[cs_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF16 ss_item_sk->[cs_item_sk]
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF15 i_item_sk->[cs_item_sk,ss_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((catalog_sales.cs_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF14 ss_item_sk->[cs_item_sk]
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF15 d_date_sk->[cs_sold_date_sk]
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF13 d_date_sk->[cs_sold_date_sk]
----------------------------------------------PhysicalProject
------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF15 RF16 RF17
------------------------------------------------PhysicalOlapScan[catalog_sales] apply RFs: RF13 RF14 RF15
----------------------------------------------PhysicalProject
------------------------------------------------filter((date_dim.d_moy = 11) and (date_dim.d_year = 2001))
--------------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF17
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF15
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
Expand All @@ -143,16 +133,16 @@ PhysicalCteAnchor ( cteId=CTEId#0 )
--------------------------------PhysicalDistribute[DistributionSpecHash]
----------------------------------hashAgg[LOCAL]
------------------------------------PhysicalProject
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF20 i_item_sk->[ss_item_sk,ws_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF19 ss_item_sk->[ws_item_sk]
--------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = item.i_item_sk)) otherCondition=() build RFs:RF18 i_item_sk->[ss_item_sk,ws_item_sk]
----------------------------------------hashJoin[LEFT_SEMI_JOIN broadcast] hashCondition=((web_sales.ws_item_sk = cross_items.ss_item_sk)) otherCondition=() build RFs:RF17 ss_item_sk->[ws_item_sk]
------------------------------------------PhysicalProject
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF18 d_date_sk->[ws_sold_date_sk]
--------------------------------------------hashJoin[INNER_JOIN broadcast] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF16 d_date_sk->[ws_sold_date_sk]
----------------------------------------------PhysicalProject
------------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF18 RF19 RF20
------------------------------------------------PhysicalOlapScan[web_sales] apply RFs: RF16 RF17 RF18
----------------------------------------------PhysicalProject
------------------------------------------------filter((date_dim.d_moy = 11) and (date_dim.d_year = 2001))
--------------------------------------------------PhysicalOlapScan[date_dim]
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF20
------------------------------------------PhysicalCteConsumer ( cteId=CTEId#0 ) apply RFs: RF18
----------------------------------------PhysicalProject
------------------------------------------PhysicalOlapScan[item]
----------------------------PhysicalProject
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalProject
------hashJoin[INNER_JOIN shuffle] hashCondition=((PULL_UP_UNIFIED_OUTPUT_ALIAS = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk,ss_customer_sk,ws_bill_customer_sk]
------hashJoin[INNER_JOIN shuffle] hashCondition=((ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk,ss_customer_sk,ws_bill_customer_sk]
--------PhysicalProject
----------PhysicalUnion
------------PhysicalProject
Expand Down
Loading