From 7fe960dd3fb320f2245c27051340faad6c66f216 Mon Sep 17 00:00:00 2001 From: xiejiann Date: Fri, 7 Jun 2024 10:53:21 +0800 Subject: [PATCH 1/9] push agg throght foreign key --- .../doris/nereids/jobs/executor/Rewriter.java | 7 +- .../doris/nereids/properties/FuncDeps.java | 4 + .../apache/doris/nereids/rules/RuleType.java | 2 +- .../rewrite/PushDownAggThroughJoinByFk.java | 134 ++++++++++++++++++ .../PushDownAggThroughJoinByFkTest.java | 63 ++++++++ 5 files changed, 207 insertions(+), 3 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index bd201ed004d9d2..babb4f5b951089 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -110,6 +110,7 @@ import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin; +import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinByFk; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide; import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; @@ -348,8 +349,10 @@ public class Rewriter extends AbstractBatchJobExecutor { ), // this rule should be invoked after topic "Join pull up" - topic("eliminate group by keys according to fd items", - topDown(new EliminateGroupByKey()) + topic("eliminate Aggregate according to fd items", + topDown(new EliminateGroupByKey()), + topDown(new PushDownAggThroughJoinByFk()), + custom(RuleType.COLUMN_PRUNING, ColumnPruning::new) ), topic("Limit optimization", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java index 6c1b302d7dc11d..52fa91b3df39c7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java @@ -144,6 +144,10 @@ public boolean isFuncDeps(Set dominate, Set dependency) { return items.contains(new FuncDepsItem(dominate, dependency)); } + public Set calBinaryDependencies(Slot slot) { + return new HashSet<>(); + } + @Override public String toString() { return items.toString(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index d6895b4121de78..ecfb939f9f16d3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -186,7 +186,7 @@ public enum RuleType { PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE), PUSH_DOWN_AGG_THROUGH_JOIN(RuleTypeClass.REWRITE), - + PUSH_DOWN_AGG_THROUGH_FK_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE), LOGICAL_SEMI_JOIN_COMMUTE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java new file mode 100644 index 00000000000000..8c05632ecc8061 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java @@ -0,0 +1,134 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.properties.DataTrait; +import org.apache.doris.nereids.properties.FuncDeps; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Project; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.JoinUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import org.apache.thrift.annotation.Nullable; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Agg(group by fk) + * | + * Join(pk = fk) + * / \ + * pk fk + * ======> + * Join(pk = fk) + * / \ + * | Agg(group by fk) + * | | + * pk fk + */ +public class PushDownAggThroughJoinByFk implements RewriteRuleFactory { + @Override + public List buildRules() { + return ImmutableList.of( + logicalAggregate( + innerLogicalJoin() + .when(j -> j.getJoinType().isInnerJoin() + && !j.isMarkJoin() + && j.getOtherJoinConjuncts().isEmpty())) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance)) + .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child())) + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_FK_JOIN), + logicalAggregate( + logicalProject( + innerLogicalJoin() + .when(j -> j.getJoinType().isInnerJoin() + && !j.isMarkJoin() + && j.getOtherJoinConjuncts().isEmpty())) + .when(Project::isAllSlots)) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance)) + .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child().child())) + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_FK_JOIN) + ); + } + + private @Nullable Plan pushAgg(LogicalAggregate agg, LogicalJoin join) { + Plan foreign = tryExtractForeign(join); + if (foreign == null) { + return null; + } + LogicalAggregate newAgg = tryGroupByForeign(agg, foreign); + if (newAgg == null) { + return null; + } + + if (join.left() == foreign) { + return join.withChildren(agg, join.right()); + } else if (join.right() == foreign) { + return join.withChildren(join.left(), agg); + } + return null; + } + + private @Nullable LogicalAggregate tryGroupByForeign(LogicalAggregate agg, Plan foreign) { + Set groupBySlots = new HashSet<>(); + for (Expression expr : agg.getGroupByExpressions()) { + groupBySlots.addAll(expr.getInputSlots()); + } + if (foreign.getOutputSet().containsAll(groupBySlots)) { + return agg; + } + Set foreignOutput = foreign.getOutputSet(); + Set primarySlots = Sets.difference(groupBySlots, foreignOutput); + DataTrait dataTrait = agg.child().getLogicalProperties().getTrait(); + FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(groupBySlots, foreign.getOutputSet())); + for (Slot slot : primarySlots) { + Set slots = funcDeps.calBinaryDependencies(slot); + Set foreignSlots = Sets.intersection(slots, foreignOutput); + if (foreignSlots.isEmpty()) { + return null; + } + groupBySlots.remove(slot); + groupBySlots.add(foreignSlots.iterator().next()); + } + Set newOutput = Sets.intersection(agg.getOutputSet(), foreignOutput); + LogicalAggregate newAgg = + agg.withGroupByAndOutput(ImmutableList.copyOf(groupBySlots), ImmutableList.copyOf(newOutput)); + return (LogicalAggregate) newAgg.withChildren(foreign); + } + + private @Nullable Plan tryExtractForeign(LogicalJoin join) { + Plan foreign; + if (JoinUtils.canEliminateByFk(join, join.left(), join.right())) { + foreign = join.right(); + } else if (JoinUtils.canEliminateByFk(join, join.right(), join.left())) { + foreign = join.left(); + } else { + return null; + } + return foreign; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java new file mode 100644 index 00000000000000..db3b942d4ffb91 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java @@ -0,0 +1,63 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +class PushDownAggThroughJoinByFkTest extends TestWithFeService implements MemoPatternMatchSupported { + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + connectContext.setDatabase("default_cluster:test"); + createTables( + "CREATE TABLE IF NOT EXISTS pri (\n" + + " id1 int not null\n" + + ")\n" + + "DUPLICATE KEY(id1)\n" + + "DISTRIBUTED BY HASH(id1) BUCKETS 10\n" + + "PROPERTIES (\"replication_num\" = \"1\")\n", + "CREATE TABLE IF NOT EXISTS foreign_not_null (\n" + + " id2 int not null\n" + + ")\n" + + "DUPLICATE KEY(id2)\n" + + "DISTRIBUTED BY HASH(id2) BUCKETS 10\n" + + "PROPERTIES (\"replication_num\" = \"1\")\n", + "CREATE TABLE IF NOT EXISTS foreign_null (\n" + + " id3 int\n" + + ")\n" + + "DUPLICATE KEY(id3)\n" + + "DISTRIBUTED BY HASH(id3) BUCKETS 10\n" + + "PROPERTIES (\"replication_num\" = \"1\")\n" + ); + addConstraint("Alter table pri add constraint pk primary key (id1)"); + addConstraint("Alter table foreign_not_null add constraint f_not_null foreign key (id2)\n" + + "references pri(id1)"); + addConstraint("Alter table foreign_null add constraint f_not_null foreign key (id3)\n" + + "references pri(id1)"); + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + } + + @Test + void test1() { + Assertions.assertTrue(true); + } +} \ No newline at end of file From be9a002551aa082621c96e04e63dae6e127be00f Mon Sep 17 00:00:00 2001 From: xiejiann Date: Fri, 7 Jun 2024 14:20:41 +0800 Subject: [PATCH 2/9] add push down agg through join with foreign key --- .../doris/nereids/jobs/executor/Rewriter.java | 3 +- .../doris/nereids/properties/FuncDeps.java | 17 +++++- .../rewrite/PushDownAggThroughJoinByFk.java | 38 ++++++++---- .../PushDownAggThroughJoinByFkTest.java | 59 ++++++++++++++++--- 4 files changed, 93 insertions(+), 24 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index babb4f5b951089..d9071a184c9c40 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -351,8 +351,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // this rule should be invoked after topic "Join pull up" topic("eliminate Aggregate according to fd items", topDown(new EliminateGroupByKey()), - topDown(new PushDownAggThroughJoinByFk()), - custom(RuleType.COLUMN_PRUNING, ColumnPruning::new) + topDown(new PushDownAggThroughJoinByFk()) ), topic("Limit optimization", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java index 52fa91b3df39c7..f320e0029f89d2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java @@ -144,8 +144,21 @@ public boolean isFuncDeps(Set dominate, Set dependency) { return items.contains(new FuncDepsItem(dominate, dependency)); } - public Set calBinaryDependencies(Slot slot) { - return new HashSet<>(); + /** + * find the cycle of dependencies + */ + public Set> calBinaryDependencies(Set slotSet) { + Set> binaryDeps = new HashSet<>(); + Set> dependencies = edges.get(slotSet); + if (dependencies == null) { + return binaryDeps; + } + for (Set other : dependencies) { + if (edges.get(other).contains(slotSet)) { + binaryDeps.add(other); + } + } + return binaryDeps; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java index 8c05632ecc8061..419ce84814c079 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.util.JoinUtils; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; import org.apache.thrift.annotation.Nullable; @@ -38,7 +39,8 @@ import java.util.Set; /** - * Agg(group by fk) + * Push down agg through join with foreign key: + * Agg(group by fk/pk) * | * Join(pk = fk) * / \ @@ -59,7 +61,8 @@ public List buildRules() { .when(j -> j.getJoinType().isInnerJoin() && !j.isMarkJoin() && j.getOtherJoinConjuncts().isEmpty())) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance)) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance) + && agg.getOutputExpressions().stream().allMatch(Slot.class::isInstance)) .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child())) .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_FK_JOIN), logicalAggregate( @@ -80,24 +83,29 @@ public List buildRules() { if (foreign == null) { return null; } - LogicalAggregate newAgg = tryGroupByForeign(agg, foreign); + Set foreignKey = Sets.intersection(join.getEqualSlots().getAllItemSet(), foreign.getOutputSet()); + LogicalAggregate newAgg = tryGroupByForeign(agg, foreign, foreignKey); if (newAgg == null) { return null; } if (join.left() == foreign) { - return join.withChildren(agg, join.right()); + return join.withChildren(newAgg, join.right()); } else if (join.right() == foreign) { - return join.withChildren(join.left(), agg); + return join.withChildren(join.left(), newAgg); } return null; } - private @Nullable LogicalAggregate tryGroupByForeign(LogicalAggregate agg, Plan foreign) { + private @Nullable LogicalAggregate tryGroupByForeign( + LogicalAggregate agg, Plan foreign, Set foreignKey) { Set groupBySlots = new HashSet<>(); for (Expression expr : agg.getGroupByExpressions()) { groupBySlots.addAll(expr.getInputSlots()); } + if (groupBySlots.containsAll(foreignKey)) { + return null; + } if (foreign.getOutputSet().containsAll(groupBySlots)) { return agg; } @@ -105,18 +113,22 @@ public List buildRules() { Set primarySlots = Sets.difference(groupBySlots, foreignOutput); DataTrait dataTrait = agg.child().getLogicalProperties().getTrait(); FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(groupBySlots, foreign.getOutputSet())); + Set newGroupBySlots = new HashSet<>(groupBySlots); for (Slot slot : primarySlots) { - Set slots = funcDeps.calBinaryDependencies(slot); - Set foreignSlots = Sets.intersection(slots, foreignOutput); - if (foreignSlots.isEmpty()) { + Set> replacedSlotSets = funcDeps.calBinaryDependencies(ImmutableSet.of(slot)); + for (Set replacedSlots : replacedSlotSets) { + if (foreignOutput.containsAll(replacedSlots)) { + newGroupBySlots.remove(slot); + newGroupBySlots.addAll(replacedSlots); + break; + } + } + if (newGroupBySlots.contains(slot)) { return null; } - groupBySlots.remove(slot); - groupBySlots.add(foreignSlots.iterator().next()); } - Set newOutput = Sets.intersection(agg.getOutputSet(), foreignOutput); LogicalAggregate newAgg = - agg.withGroupByAndOutput(ImmutableList.copyOf(groupBySlots), ImmutableList.copyOf(newOutput)); + agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newGroupBySlots)); return (LogicalAggregate) newAgg.withChildren(foreign); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java index db3b942d4ffb91..7103de93c9cae4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java @@ -18,9 +18,9 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.utframe.TestWithFeService; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; class PushDownAggThroughJoinByFkTest extends TestWithFeService implements MemoPatternMatchSupported { @@ -30,19 +30,22 @@ protected void runBeforeAll() throws Exception { connectContext.setDatabase("default_cluster:test"); createTables( "CREATE TABLE IF NOT EXISTS pri (\n" - + " id1 int not null\n" + + " id1 int not null,\n" + + " name char\n" + ")\n" + "DUPLICATE KEY(id1)\n" + "DISTRIBUTED BY HASH(id1) BUCKETS 10\n" + "PROPERTIES (\"replication_num\" = \"1\")\n", "CREATE TABLE IF NOT EXISTS foreign_not_null (\n" - + " id2 int not null\n" + + " id2 int not null,\n" + + " name char\n" + ")\n" + "DUPLICATE KEY(id2)\n" + "DISTRIBUTED BY HASH(id2) BUCKETS 10\n" + "PROPERTIES (\"replication_num\" = \"1\")\n", "CREATE TABLE IF NOT EXISTS foreign_null (\n" - + " id3 int\n" + + " id3 int,\n" + + " name char\n" + ")\n" + "DUPLICATE KEY(id3)\n" + "DISTRIBUTED BY HASH(id3) BUCKETS 10\n" @@ -57,7 +60,49 @@ protected void runBeforeAll() throws Exception { } @Test - void test1() { - Assertions.assertTrue(true); + void testGroupByFk() { + String sql = "select pri.id1 from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalAggregate())) + .printlnTree(); } -} \ No newline at end of file + + @Test + void testGroupByFkAndOther() { + String sql = "select pri.id1 from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1, foreign_not_null.name"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalProject(logicalAggregate()))) + .printlnTree(); + sql = "select pri.id1 from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1, pri.name"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalAggregate())) + .printlnTree(); + sql = "select pri.id1 from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1, pri.name, foreign_not_null.name"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalProject(logicalAggregate()))) + .printlnTree(); + } + + @Test + void testGroupByFkWithAggFunc() { + String sql = "select count(pri.id1) from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalAggregate(logicalProject(logicalJoin()))) + .printlnTree(); + } +} From 00dbd2be12468ee9ae01da5922279d63a285b236 Mon Sep 17 00:00:00 2001 From: xiejiann Date: Mon, 17 Jun 2024 17:56:37 +0800 Subject: [PATCH 3/9] change rule type --- .../org/apache/doris/nereids/jobs/executor/Rewriter.java | 4 ++-- .../main/java/org/apache/doris/nereids/rules/RuleType.java | 2 +- ...roughJoinByFk.java => PushDownAggThroughJoinOnPkFk.java} | 6 +++--- ...nByFkTest.java => PushDownAggThroughJoinOnPkFkTest.java} | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) rename fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/{PushDownAggThroughJoinByFk.java => PushDownAggThroughJoinOnPkFk.java} (96%) rename fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/{PushDownAggThroughJoinByFkTest.java => PushDownAggThroughJoinOnPkFkTest.java} (97%) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index d9071a184c9c40..9e698e1f6eaf1d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -110,7 +110,7 @@ import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoJdbcScan; import org.apache.doris.nereids.rules.rewrite.PushConjunctsIntoOdbcScan; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoin; -import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinByFk; +import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOnPkFk; import org.apache.doris.nereids.rules.rewrite.PushDownAggThroughJoinOneSide; import org.apache.doris.nereids.rules.rewrite.PushDownDistinctThroughJoin; import org.apache.doris.nereids.rules.rewrite.PushDownFilterThroughProject; @@ -351,7 +351,7 @@ public class Rewriter extends AbstractBatchJobExecutor { // this rule should be invoked after topic "Join pull up" topic("eliminate Aggregate according to fd items", topDown(new EliminateGroupByKey()), - topDown(new PushDownAggThroughJoinByFk()) + topDown(new PushDownAggThroughJoinOnPkFk()) ), topic("Limit optimization", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index ecfb939f9f16d3..dcd36420c7a56a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -186,7 +186,7 @@ public enum RuleType { PUSH_DOWN_AGG_THROUGH_JOIN_ONE_SIDE(RuleTypeClass.REWRITE), PUSH_DOWN_AGG_THROUGH_JOIN(RuleTypeClass.REWRITE), - PUSH_DOWN_AGG_THROUGH_FK_JOIN(RuleTypeClass.REWRITE), + PUSH_DOWN_AGG_THROUGH_JOIN_ON_PKFK(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN(RuleTypeClass.REWRITE), TRANSPOSE_LOGICAL_SEMI_JOIN_LOGICAL_JOIN_PROJECT(RuleTypeClass.REWRITE), LOGICAL_SEMI_JOIN_COMMUTE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java similarity index 96% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index 419ce84814c079..cb15e2a5bf11db 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -52,7 +52,7 @@ * | | * pk fk */ -public class PushDownAggThroughJoinByFk implements RewriteRuleFactory { +public class PushDownAggThroughJoinOnPkFk implements RewriteRuleFactory { @Override public List buildRules() { return ImmutableList.of( @@ -64,7 +64,7 @@ public List buildRules() { .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance) && agg.getOutputExpressions().stream().allMatch(Slot.class::isInstance)) .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child())) - .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_FK_JOIN), + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ON_PKFK), logicalAggregate( logicalProject( innerLogicalJoin() @@ -74,7 +74,7 @@ public List buildRules() { .when(Project::isAllSlots)) .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance)) .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child().child())) - .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_FK_JOIN) + .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ON_PKFK) ); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java similarity index 97% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java index 7103de93c9cae4..c39b932b454d9c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinByFkTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java @@ -23,7 +23,7 @@ import org.junit.jupiter.api.Test; -class PushDownAggThroughJoinByFkTest extends TestWithFeService implements MemoPatternMatchSupported { +class PushDownAggThroughJoinOnPkFkTest extends TestWithFeService implements MemoPatternMatchSupported { @Override protected void runBeforeAll() throws Exception { createDatabase("test"); From 2c7d9f233268a7b6dff9b4294ced0cec00a20105 Mon Sep 17 00:00:00 2001 From: xiejiann Date: Tue, 18 Jun 2024 15:06:47 +0800 Subject: [PATCH 4/9] fix agg func --- .../rewrite/PushDownAggThroughJoinOnPkFk.java | 74 +++++++++++++++---- .../PushDownAggThroughJoinOnPkFkTest.java | 31 +++++++- 2 files changed, 88 insertions(+), 17 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index cb15e2a5bf11db..34ee06c75813ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -21,8 +21,11 @@ import org.apache.doris.nereids.properties.FuncDeps; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Project; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -34,8 +37,11 @@ import com.google.common.collect.Sets; import org.apache.thrift.annotation.Nullable; +import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; /** @@ -61,8 +67,7 @@ public List buildRules() { .when(j -> j.getJoinType().isInnerJoin() && !j.isMarkJoin() && j.getOtherJoinConjuncts().isEmpty())) - .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance) - && agg.getOutputExpressions().stream().allMatch(Slot.class::isInstance)) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance)) .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child())) .toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ON_PKFK), logicalAggregate( @@ -83,8 +88,7 @@ public List buildRules() { if (foreign == null) { return null; } - Set foreignKey = Sets.intersection(join.getEqualSlots().getAllItemSet(), foreign.getOutputSet()); - LogicalAggregate newAgg = tryGroupByForeign(agg, foreign, foreignKey); + LogicalAggregate newAgg = tryGroupByForeign(agg, foreign); if (newAgg == null) { return null; } @@ -97,15 +101,51 @@ public List buildRules() { return null; } + private @Nullable Set constructNewGroupBy(LogicalAggregate agg, Set foreignOutput, + Map primaryToForeignBiDeps) { + Set newGroupBySlots = new HashSet<>(); + for (Expression expression : agg.getGroupByExpressions()) { + if (!(expression instanceof Slot)) { + return null; + } + if (!foreignOutput.contains((Slot) expression) + && !primaryToForeignBiDeps.containsKey((Slot) expression)) { + return null; + } + expression = primaryToForeignBiDeps.getOrDefault(expression, (Slot) expression); + newGroupBySlots.add(expression); + } + return newGroupBySlots; + } + + private @Nullable List constructNewOutput(LogicalAggregate agg, Set foreignOutput, + Map primaryToForeignBiDeps) { + List newOutput = new ArrayList<>(); + for (NamedExpression expression : agg.getOutputExpressions()) { + if (expression instanceof Slot && primaryToForeignBiDeps.containsKey(expression)) { + expression = primaryToForeignBiDeps.getOrDefault(expression, expression.toSlot()); + } + if (expression instanceof Alias && expression.child(0) instanceof Count) { + expression = (NamedExpression) expression.rewriteUp(e -> + e instanceof Slot + ? primaryToForeignBiDeps.getOrDefault((Slot) e, (Slot) e) + : e); + } + if (!(expression instanceof Slot) + && !foreignOutput.containsAll(expression.getInputSlots())) { + return null; + } + newOutput.add(expression); + } + return newOutput; + } + private @Nullable LogicalAggregate tryGroupByForeign( - LogicalAggregate agg, Plan foreign, Set foreignKey) { + LogicalAggregate agg, Plan foreign) { Set groupBySlots = new HashSet<>(); for (Expression expr : agg.getGroupByExpressions()) { groupBySlots.addAll(expr.getInputSlots()); } - if (groupBySlots.containsAll(foreignKey)) { - return null; - } if (foreign.getOutputSet().containsAll(groupBySlots)) { return agg; } @@ -113,22 +153,24 @@ public List buildRules() { Set primarySlots = Sets.difference(groupBySlots, foreignOutput); DataTrait dataTrait = agg.child().getLogicalProperties().getTrait(); FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(groupBySlots, foreign.getOutputSet())); - Set newGroupBySlots = new HashSet<>(groupBySlots); + HashMap primaryToForeignBiDeps = new HashMap<>(); for (Slot slot : primarySlots) { Set> replacedSlotSets = funcDeps.calBinaryDependencies(ImmutableSet.of(slot)); for (Set replacedSlots : replacedSlotSets) { - if (foreignOutput.containsAll(replacedSlots)) { - newGroupBySlots.remove(slot); - newGroupBySlots.addAll(replacedSlots); + if (foreignOutput.containsAll(replacedSlots) && replacedSlots.size() == 1) { + primaryToForeignBiDeps.put(slot, replacedSlots.iterator().next()); break; } } - if (newGroupBySlots.contains(slot)) { - return null; - } + } + + Set newGroupBySlots = constructNewGroupBy(agg, foreignOutput, primaryToForeignBiDeps); + List newOutput = constructNewOutput(agg, foreignOutput, primaryToForeignBiDeps); + if (newGroupBySlots == null || newOutput == null) { + return null; } LogicalAggregate newAgg = - agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newGroupBySlots)); + agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newOutput)); return (LogicalAggregate) newAgg.withChildren(foreign); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java index c39b932b454d9c..b2f01c780be289 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java @@ -96,9 +96,38 @@ void testGroupByFkAndOther() { } @Test - void testGroupByFkWithAggFunc() { + void testGroupByFkWithCount() { String sql = "select count(pri.id1) from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + "group by foreign_not_null.id2, pri.id1"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalAggregate())) + .printlnTree(); + sql = "select count(foreign_not_null.id2) from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalAggregate())) + .printlnTree(); + } + + @Test + void testGroupByFkWithForeigAgg() { + String sql = "select sum(foreign_not_null.id2) from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(any(), logicalAggregate())) + .printlnTree(); + } + + @Test + void testGroupByFkWithPrimaryAgg() { + String sql = "select sum(pri.id1) from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1"; PlanChecker.from(connectContext) .analyze(sql) .rewrite() From fa120c189917b916e920d1c728896a1a0ea3983d Mon Sep 17 00:00:00 2001 From: xiejiann Date: Fri, 21 Jun 2024 08:49:05 +0800 Subject: [PATCH 5/9] process complex join --- .../rewrite/PushDownAggThroughJoinOnPkFk.java | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index 34ee06c75813ef..e0e43e1517b3a6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -30,11 +30,13 @@ import org.apache.doris.nereids.trees.plans.algebra.Project; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.JoinUtils; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; +import org.apache.hadoop.yarn.webapp.hamlet2.Hamlet.P; import org.apache.thrift.annotation.Nullable; import java.util.ArrayList; @@ -83,6 +85,10 @@ public List buildRules() { ); } + private List> collectContiguousInnerJoin() { + + } + private @Nullable Plan pushAgg(LogicalAggregate agg, LogicalJoin join) { Plan foreign = tryExtractForeign(join); if (foreign == null) { @@ -185,4 +191,23 @@ public List buildRules() { } return foreign; } + + static class InnerJoinCluster { + private List> contiguousInnerJoins = new ArrayList<>(); + private List leaf = new ArrayList<>(); + + void collectContiguousInnerJoins(Plan plan) { + if (plan instanceof LogicalProject) { + boolean isSlotProject = ((LogicalProject) plan).getProjects().stream() + .allMatch(Slot.class::isInstance); + if (!isSlotProject) { + leaf.add(plan); + return; + } + } + for (Plan child : plan.children()) { + collectContiguousInnerJoins(child); + } + } + } } From e70a6e41276ebe6f6b3e2e5693889b7656e6dc56 Mon Sep 17 00:00:00 2001 From: xiejiann Date: Fri, 21 Jun 2024 11:28:55 +0800 Subject: [PATCH 6/9] process complex join --- .../doris/nereids/properties/FuncDeps.java | 24 +- .../rewrite/PushDownAggThroughJoinOnPkFk.java | 248 +++++++++++++----- .../PushDownAggThroughJoinOnPkFkTest.java | 35 ++- .../shape/query38.out | 51 ++-- .../shape/query87.out | 51 ++-- .../noStatsRfPrune/query38.out | 51 ++-- .../noStatsRfPrune/query87.out | 51 ++-- .../no_stats_shape/query38.out | 51 ++-- .../no_stats_shape/query87.out | 51 ++-- .../rf_prune/query38.out | 51 ++-- .../rf_prune/query87.out | 51 ++-- .../shape/query38.out | 51 ++-- .../shape/query87.out | 51 ++-- 13 files changed, 457 insertions(+), 360 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java index f320e0029f89d2..9c06c531e7ddbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FuncDeps.java @@ -61,6 +61,7 @@ public int hashCode() { } private final Set items; + // determinants -> dependencies private final Map, Set>> edges; public FuncDeps() { @@ -144,21 +145,22 @@ public boolean isFuncDeps(Set dominate, Set dependency) { return items.contains(new FuncDepsItem(dominate, dependency)); } + public boolean isCircleDeps(Set dominate, Set dependency) { + return items.contains(new FuncDepsItem(dominate, dependency)) + && items.contains(new FuncDepsItem(dependency, dominate)); + } + /** - * find the cycle of dependencies + * find the determinants of dependencies */ - public Set> calBinaryDependencies(Set slotSet) { - Set> binaryDeps = new HashSet<>(); - Set> dependencies = edges.get(slotSet); - if (dependencies == null) { - return binaryDeps; - } - for (Set other : dependencies) { - if (edges.get(other).contains(slotSet)) { - binaryDeps.add(other); + public Set> findDeterminats(Set dependency) { + Set> determinants = new HashSet<>(); + for (FuncDepsItem item : items) { + if (item.dependencies.equals(dependency)) { + determinants.add(item.determinants); } } - return binaryDeps; + return determinants; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index e0e43e1517b3a6..4b89e3f0fa8297 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.properties.DataTrait; import org.apache.doris.nereids.properties.FuncDeps; import org.apache.doris.nereids.rules.Rule; @@ -36,14 +37,15 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Sets; -import org.apache.hadoop.yarn.webapp.hamlet2.Hamlet.P; import org.apache.thrift.annotation.Nullable; import java.util.ArrayList; +import java.util.BitSet; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; /** @@ -85,36 +87,72 @@ public List buildRules() { ); } - private List> collectContiguousInnerJoin() { - - } - private @Nullable Plan pushAgg(LogicalAggregate agg, LogicalJoin join) { - Plan foreign = tryExtractForeign(join); - if (foreign == null) { + InnerJoinCluster innerJoinCluster = new InnerJoinCluster(); + innerJoinCluster.collectContiguousInnerJoins(join); + if (!innerJoinCluster.isValid()) { return null; } - LogicalAggregate newAgg = tryGroupByForeign(agg, foreign); - if (newAgg == null) { - return null; + for (Entry> e : innerJoinCluster.getJoinsMap().entrySet()) { + LogicalJoin subJoin = e.getValue(); + Pair primaryAndForeign = tryExtractPrimaryForeign(subJoin); + if (primaryAndForeign == null) { + continue; + } + LogicalAggregate newAgg = eliminatePrimaryOutput(agg, primaryAndForeign.first, primaryAndForeign.second); + if (newAgg == null) { + return null; + } + LogicalJoin newJoin = innerJoinCluster + .constructJoinWithPrimary(e.getKey(), subJoin, primaryAndForeign.first); + if (newJoin != null && newJoin.left() == primaryAndForeign.first) { + return newJoin.withChildren(newJoin.left(), newAgg.withChildren(newJoin.right())); + } else if (newJoin != null && newJoin.right() == primaryAndForeign.first) { + return newJoin.withChildren(newAgg.withChildren(newJoin.left()), newJoin.right()); + } } + return null; + } - if (join.left() == foreign) { - return join.withChildren(newAgg, join.right()); - } else if (join.right() == foreign) { - return join.withChildren(join.left(), newAgg); + // eliminate the slot of primary plan in agg + private LogicalAggregate eliminatePrimaryOutput(LogicalAggregate agg, + Plan primary, Plan foreign) { + Set aggInputs = agg.getInputSlots(); + if (primary.getOutputSet().stream().noneMatch(aggInputs::contains)) { + return agg; } - return null; + Set primaryOutputSet = primary.getOutputSet(); + Set primarySlots = Sets.intersection(aggInputs, primaryOutputSet); + DataTrait dataTrait = agg.child().getLogicalProperties().getTrait(); + FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(foreign.getOutputSet(), primary.getOutputSet())); + HashMap primaryToForeignDeps = new HashMap<>(); + for (Slot slot : primarySlots) { + Set> replacedSlotSets = funcDeps.findDeterminats(ImmutableSet.of(slot)); + for (Set replacedSlots : replacedSlotSets) { + if (primaryOutputSet.stream().noneMatch(replacedSlots::contains) + && replacedSlots.size() == 1) { + primaryToForeignDeps.put(slot, replacedSlots.iterator().next()); + break; + } + } + } + + Set newGroupBySlots = constructNewGroupBy(agg, primaryOutputSet, primaryToForeignDeps); + List newOutput = constructNewOutput(agg, primaryOutputSet, primaryToForeignDeps, funcDeps); + if (newGroupBySlots == null || newOutput == null) { + return null; + } + return agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newOutput)); } - private @Nullable Set constructNewGroupBy(LogicalAggregate agg, Set foreignOutput, + private @Nullable Set constructNewGroupBy(LogicalAggregate agg, Set primaryOutputs, Map primaryToForeignBiDeps) { Set newGroupBySlots = new HashSet<>(); for (Expression expression : agg.getGroupByExpressions()) { if (!(expression instanceof Slot)) { return null; } - if (!foreignOutput.contains((Slot) expression) + if (primaryOutputs.contains((Slot) expression) && !primaryToForeignBiDeps.containsKey((Slot) expression)) { return null; } @@ -124,21 +162,37 @@ public List buildRules() { return newGroupBySlots; } - private @Nullable List constructNewOutput(LogicalAggregate agg, Set foreignOutput, - Map primaryToForeignBiDeps) { + private @Nullable List constructNewOutput(LogicalAggregate agg, Set primaryOutput, + Map primaryToForeignDeps, FuncDeps funcDeps) { List newOutput = new ArrayList<>(); for (NamedExpression expression : agg.getOutputExpressions()) { - if (expression instanceof Slot && primaryToForeignBiDeps.containsKey(expression)) { - expression = primaryToForeignBiDeps.getOrDefault(expression, expression.toSlot()); + // There are three cases for output expressions: + // 1. Slot: the slot is from primary plan, we need to replace it with + // the corresponding slot from foreign plan, + // or skip it when it isn't in group by. + // 2. Count: the count is from primary plan, + // we need to replace the slot in the count with the corresponding slot + // from foreign plan + // 3. Others: the expression is not from primary plan, we need to keep it + if (expression instanceof Slot && primaryToForeignDeps.containsKey(expression)) { + expression = primaryToForeignDeps.getOrDefault(expression, expression.toSlot()); } - if (expression instanceof Alias && expression.child(0) instanceof Count) { - expression = (NamedExpression) expression.rewriteUp(e -> - e instanceof Slot - ? primaryToForeignBiDeps.getOrDefault((Slot) e, (Slot) e) - : e); + if (expression instanceof Alias + && expression.child(0) instanceof Count + && expression.child(0).child(0) instanceof Slot) { + // count(slot) can be rewritten by circle deps + Slot slot = (Slot) expression.child(0).child(0); + if (primaryToForeignDeps.containsKey(slot) + && funcDeps.isCircleDeps( + ImmutableSet.of(slot), ImmutableSet.of(primaryToForeignDeps.get(slot)))) { + expression = (NamedExpression) expression.rewriteUp(e -> + e instanceof Slot + ? primaryToForeignDeps.getOrDefault((Slot) e, (Slot) e) + : e); + } } if (!(expression instanceof Slot) - && !foreignOutput.containsAll(expression.getInputSlots())) { + && expression.getInputSlots().stream().anyMatch(primaryOutput::contains)) { return null; } newOutput.add(expression); @@ -146,68 +200,118 @@ public List buildRules() { return newOutput; } - private @Nullable LogicalAggregate tryGroupByForeign( - LogicalAggregate agg, Plan foreign) { - Set groupBySlots = new HashSet<>(); - for (Expression expr : agg.getGroupByExpressions()) { - groupBySlots.addAll(expr.getInputSlots()); - } - if (foreign.getOutputSet().containsAll(groupBySlots)) { - return agg; - } - Set foreignOutput = foreign.getOutputSet(); - Set primarySlots = Sets.difference(groupBySlots, foreignOutput); - DataTrait dataTrait = agg.child().getLogicalProperties().getTrait(); - FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(groupBySlots, foreign.getOutputSet())); - HashMap primaryToForeignBiDeps = new HashMap<>(); - for (Slot slot : primarySlots) { - Set> replacedSlotSets = funcDeps.calBinaryDependencies(ImmutableSet.of(slot)); - for (Set replacedSlots : replacedSlotSets) { - if (foreignOutput.containsAll(replacedSlots) && replacedSlots.size() == 1) { - primaryToForeignBiDeps.put(slot, replacedSlots.iterator().next()); - break; - } - } - } - - Set newGroupBySlots = constructNewGroupBy(agg, foreignOutput, primaryToForeignBiDeps); - List newOutput = constructNewOutput(agg, foreignOutput, primaryToForeignBiDeps); - if (newGroupBySlots == null || newOutput == null) { - return null; - } - LogicalAggregate newAgg = - agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newOutput)); - return (LogicalAggregate) newAgg.withChildren(foreign); - } - - private @Nullable Plan tryExtractForeign(LogicalJoin join) { + private @Nullable Pair tryExtractPrimaryForeign(LogicalJoin join) { + Plan primary; Plan foreign; if (JoinUtils.canEliminateByFk(join, join.left(), join.right())) { + primary = join.left(); foreign = join.right(); } else if (JoinUtils.canEliminateByFk(join, join.right(), join.left())) { + primary = join.right(); foreign = join.left(); } else { return null; } - return foreign; + return Pair.of(primary, foreign); } static class InnerJoinCluster { - private List> contiguousInnerJoins = new ArrayList<>(); - private List leaf = new ArrayList<>(); + private final Map> innerJoins = new HashMap<>(); + private final List leaf = new ArrayList<>(); void collectContiguousInnerJoins(Plan plan) { - if (plan instanceof LogicalProject) { - boolean isSlotProject = ((LogicalProject) plan).getProjects().stream() - .allMatch(Slot.class::isInstance); - if (!isSlotProject) { - leaf.add(plan); - return; - } + if (!isSlotProject(plan) && !isInnerJoin(plan)) { + leaf.add(plan); + return; } for (Plan child : plan.children()) { collectContiguousInnerJoins(child); } + if (isInnerJoin(plan)) { + LogicalJoin join = (LogicalJoin) plan; + Set inputSlots = join.getInputSlots(); + BitSet childrenIndices = new BitSet(); + List children = new ArrayList<>(); + for (int i = 0; i < leaf.size(); i++) { + if (!Sets.intersection(leaf.get(i).getOutputSet(), inputSlots).isEmpty()) { + childrenIndices.set(i); + children.add(leaf.get(i)); + } + } + if (childrenIndices.cardinality() == 2) { + join = join.withChildren(children); + } + innerJoins.put(childrenIndices, join); + } + } + + boolean isValid() { + // we cannot handle the case that there is any join with more than 2 children + return innerJoins.keySet().stream().allMatch(x -> x.cardinality() == 2); + } + + @Nullable LogicalJoin constructJoinWithPrimary(BitSet bitSet, LogicalJoin join, Plan primary) { + Set forbiddenJoin = new HashSet<>(); + forbiddenJoin.add(bitSet); + BitSet totalBitset = new BitSet(); + totalBitset.set(0, leaf.size()); + totalBitset.set(leaf.indexOf(primary), false); + Plan childPlan = constructPlan(totalBitset, forbiddenJoin); + if (childPlan == null) { + return null; + } + return (LogicalJoin) join.withChildren(childPlan, primary); + } + + @Nullable Plan constructPlan(BitSet bitSet, Set forbiddenJoin) { + if (bitSet.cardinality() == 1) { + return leaf.get(bitSet.nextSetBit(0)); + } + + BitSet currentBitset = new BitSet(); + Plan currentPlan = null; + while (!currentBitset.equals(bitSet)) { + boolean addJoin = false; + for (Entry> entry : innerJoins.entrySet()) { + if (forbiddenJoin.contains(entry.getKey())) { + continue; + } + if (currentBitset.isEmpty()) { + addJoin = true; + currentBitset.or(entry.getKey()); + currentPlan = entry.getValue(); + forbiddenJoin.add(entry.getKey()); + } else if (currentBitset.intersects(entry.getKey())) { + addJoin = true; + currentBitset.or(entry.getKey()); + currentPlan = currentPlan.withChildren(currentPlan, entry.getValue()); + forbiddenJoin.add(entry.getKey()); + } + } + if (!addJoin) { + // if we cannot find any join to add, just return null + // It means we cannot construct a join + return null; + } + } + return currentPlan; + } + + Map> getJoinsMap() { + return innerJoins; + } + + boolean isSlotProject(Plan plan) { + return plan instanceof LogicalProject + && ((LogicalProject) (plan)).isAllSlots(); + + } + + boolean isInnerJoin(Plan plan) { + return plan instanceof LogicalJoin + && ((LogicalJoin) plan).getJoinType().isInnerJoin() + && !((LogicalJoin) plan).isMarkJoin() + && ((LogicalJoin) plan).getOtherJoinConjuncts().isEmpty(); } } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java index b2f01c780be289..91e66790002fb7 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFkTest.java @@ -66,7 +66,7 @@ void testGroupByFk() { PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalAggregate())) + .matches(logicalJoin(logicalAggregate(), any())) .printlnTree(); } @@ -77,21 +77,21 @@ void testGroupByFkAndOther() { PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalProject(logicalAggregate()))) + .matches(logicalJoin(logicalProject(logicalAggregate()), any())) .printlnTree(); sql = "select pri.id1 from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + "group by foreign_not_null.id2, pri.id1, pri.name"; PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalAggregate())) + .matches(logicalJoin(logicalAggregate(), any())) .printlnTree(); sql = "select pri.id1 from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + "group by foreign_not_null.id2, pri.id1, pri.name, foreign_not_null.name"; PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalProject(logicalAggregate()))) + .matches(logicalJoin(logicalProject(logicalAggregate()), any())) .printlnTree(); } @@ -102,14 +102,14 @@ void testGroupByFkWithCount() { PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalAggregate())) + .matches(logicalJoin(logicalAggregate(), any())) .printlnTree(); sql = "select count(foreign_not_null.id2) from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + "group by foreign_not_null.id2, pri.id1"; PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalAggregate())) + .matches(logicalJoin(logicalAggregate(), any())) .printlnTree(); } @@ -120,7 +120,7 @@ void testGroupByFkWithForeigAgg() { PlanChecker.from(connectContext) .analyze(sql) .rewrite() - .matches(logicalJoin(any(), logicalAggregate())) + .matches(logicalJoin(logicalAggregate(), any())) .printlnTree(); } @@ -134,4 +134,25 @@ void testGroupByFkWithPrimaryAgg() { .matches(logicalAggregate(logicalProject(logicalJoin()))) .printlnTree(); } + + @Test + void testMultiJoin() { + String sql = "select count(pri.id1), pri.name from foreign_not_null inner join foreign_null on foreign_null.name = foreign_not_null.name\n" + + " inner join pri on pri.id1 = foreign_not_null.id2\n" + + "group by foreign_not_null.id2, pri.id1, pri.name"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(logicalAggregate(), any())) + .printlnTree(); + + sql = "select count(pri.id1), pri.name from pri inner join foreign_not_null on pri.id1 = foreign_not_null.id2\n" + + "inner join foreign_null on foreign_null.name = foreign_not_null.name\n" + + "group by foreign_not_null.id2, pri.id1, pri.name"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches(logicalJoin(logicalAggregate(), any())) + .printlnTree(); + } } diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query38.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query38.out index 442e184d622e0f..6f065315ce78dc 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query38.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query38.out @@ -8,12 +8,12 @@ PhysicalResultSink ----------hashAgg[LOCAL] ------------PhysicalProject --------------PhysicalIntersect -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ws_bill_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ws_bill_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ws_sold_date_sk] --------------------------------PhysicalProject @@ -22,15 +22,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1200) and (date_dim.d_month_seq >= 1189)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] --------------------------------PhysicalProject @@ -39,15 +38,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1200) and (date_dim.d_month_seq >= 1189)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ss_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ss_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject @@ -56,7 +54,6 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1200) and (date_dim.d_month_seq >= 1189)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query87.out b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query87.out index 20ece13139de00..181069d1e2c8bc 100644 --- a/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query87.out +++ b/regression-test/data/nereids_tpcds_shape_sf1000_p0/shape/query87.out @@ -6,12 +6,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalProject ----------PhysicalExcept -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] ----------------------------PhysicalProject @@ -20,15 +20,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1213) and (date_dim.d_month_seq >= 1202)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] ----------------------------PhysicalProject @@ -37,15 +36,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1213) and (date_dim.d_month_seq >= 1202)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ws_bill_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ws_bill_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ws_sold_date_sk] ----------------------------PhysicalProject @@ -54,7 +52,6 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1213) and (date_dim.d_month_seq >= 1202)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query38.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query38.out index c20ce4bb74100d..cae7a8949877ec 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query38.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query38.out @@ -8,12 +8,12 @@ PhysicalResultSink ----------hashAgg[LOCAL] ------------PhysicalProject --------------PhysicalIntersect -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() ---------------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ws_sold_date_sk] --------------------------------PhysicalProject @@ -22,15 +22,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] --------------------------------PhysicalProject @@ -39,15 +38,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject @@ -56,7 +54,6 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query87.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query87.out index 2a1cbdc26556b2..9fb612c62ce130 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query87.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/noStatsRfPrune/query87.out @@ -6,12 +6,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalProject ----------PhysicalExcept -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() -----------------------PhysicalDistribute[DistributionSpecHash] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] ----------------------------PhysicalProject @@ -20,15 +20,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] ----------------------------PhysicalProject @@ -37,15 +36,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ws_sold_date_sk] ----------------------------PhysicalProject @@ -54,7 +52,6 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query38.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query38.out index ff4380040ba0c7..3a1334802da1a2 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query38.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query38.out @@ -8,12 +8,12 @@ PhysicalResultSink ----------hashAgg[LOCAL] ------------PhysicalProject --------------PhysicalIntersect -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ws_bill_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ws_bill_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ws_sold_date_sk] --------------------------------PhysicalProject @@ -22,15 +22,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] --------------------------------PhysicalProject @@ -39,15 +38,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ss_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ss_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject @@ -56,7 +54,6 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query87.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query87.out index abf84ad40ba111..66f357bee2c2df 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query87.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/no_stats_shape/query87.out @@ -6,12 +6,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalProject ----------PhysicalExcept -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] ----------------------------PhysicalProject @@ -20,15 +20,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] ----------------------------PhysicalProject @@ -37,15 +36,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ws_bill_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ws_bill_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ws_sold_date_sk] ----------------------------PhysicalProject @@ -54,7 +52,6 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query38.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query38.out index c20ce4bb74100d..cae7a8949877ec 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query38.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query38.out @@ -8,12 +8,12 @@ PhysicalResultSink ----------hashAgg[LOCAL] ------------PhysicalProject --------------PhysicalIntersect -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() ---------------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ws_sold_date_sk] --------------------------------PhysicalProject @@ -22,15 +22,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] --------------------------------PhysicalProject @@ -39,15 +38,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject @@ -56,7 +54,6 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query87.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query87.out index 2a1cbdc26556b2..9fb612c62ce130 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query87.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/rf_prune/query87.out @@ -6,12 +6,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalProject ----------PhysicalExcept -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() -----------------------PhysicalDistribute[DistributionSpecHash] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] ----------------------------PhysicalProject @@ -20,15 +20,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] ----------------------------PhysicalProject @@ -37,15 +36,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ws_sold_date_sk] ----------------------------PhysicalProject @@ -54,7 +52,6 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query38.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query38.out index ff4380040ba0c7..3a1334802da1a2 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query38.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query38.out @@ -8,12 +8,12 @@ PhysicalResultSink ----------hashAgg[LOCAL] ------------PhysicalProject --------------PhysicalIntersect -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ws_bill_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ws_bill_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ws_sold_date_sk] --------------------------------PhysicalProject @@ -22,15 +22,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] --------------------------------PhysicalProject @@ -39,15 +38,14 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] -----------------hashAgg[GLOBAL] -------------------PhysicalDistribute[DistributionSpecHash] ---------------------hashAgg[LOCAL] -----------------------PhysicalProject -------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ss_customer_sk] ---------------------------PhysicalDistribute[DistributionSpecHash] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ss_customer_sk] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[GLOBAL] +------------------------PhysicalDistribute[DistributionSpecHash] +--------------------------hashAgg[LOCAL] ----------------------------PhysicalProject ------------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ss_sold_date_sk] --------------------------------PhysicalProject @@ -56,7 +54,6 @@ PhysicalResultSink ----------------------------------PhysicalProject ------------------------------------filter((date_dim.d_month_seq <= 1194) and (date_dim.d_month_seq >= 1183)) --------------------------------------PhysicalOlapScan[date_dim] ---------------------------PhysicalDistribute[DistributionSpecHash] -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[customer] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query87.out b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query87.out index abf84ad40ba111..66f357bee2c2df 100644 --- a/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query87.out +++ b/regression-test/data/nereids_tpcds_shape_sf100_p0/shape/query87.out @@ -6,12 +6,12 @@ PhysicalResultSink ------hashAgg[LOCAL] --------PhysicalProject ----------PhysicalExcept -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF1 c_customer_sk->[ss_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((store_sales.ss_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF0 d_date_sk->[ss_sold_date_sk] ----------------------------PhysicalProject @@ -20,15 +20,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF3 c_customer_sk->[cs_bill_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((catalog_sales.cs_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF2 d_date_sk->[cs_sold_date_sk] ----------------------------PhysicalProject @@ -37,15 +36,14 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] -------------hashAgg[GLOBAL] ---------------PhysicalDistribute[DistributionSpecHash] -----------------hashAgg[LOCAL] -------------------PhysicalProject ---------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ws_bill_customer_sk] -----------------------PhysicalDistribute[DistributionSpecHash] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] +------------PhysicalDistribute[DistributionSpecHash] +--------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_bill_customer_sk = customer.c_customer_sk)) otherCondition=() build RFs:RF5 c_customer_sk->[ws_bill_customer_sk] +----------------PhysicalDistribute[DistributionSpecHash] +------------------hashAgg[GLOBAL] +--------------------PhysicalDistribute[DistributionSpecHash] +----------------------hashAgg[LOCAL] ------------------------PhysicalProject --------------------------hashJoin[INNER_JOIN] hashCondition=((web_sales.ws_sold_date_sk = date_dim.d_date_sk)) otherCondition=() build RFs:RF4 d_date_sk->[ws_sold_date_sk] ----------------------------PhysicalProject @@ -54,7 +52,6 @@ PhysicalResultSink ------------------------------PhysicalProject --------------------------------filter((date_dim.d_month_seq <= 1195) and (date_dim.d_month_seq >= 1184)) ----------------------------------PhysicalOlapScan[date_dim] -----------------------PhysicalDistribute[DistributionSpecHash] -------------------------PhysicalProject ---------------------------PhysicalOlapScan[customer] +----------------PhysicalDistribute[DistributionSpecHash] +------------------PhysicalOlapScan[customer] From 6b3a67b785ede2462ce60eb06fb06c06d75c5415 Mon Sep 17 00:00:00 2001 From: xiejiann Date: Thu, 27 Jun 2024 09:58:28 +0800 Subject: [PATCH 7/9] add comment --- .../rewrite/PushDownAggThroughJoinOnPkFk.java | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index 4b89e3f0fa8297..08d53acaf6b089 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -200,6 +200,7 @@ private LogicalAggregate eliminatePrimaryOutput(LogicalAggregate agg, return newOutput; } + // try to extract primary key table and foreign key table private @Nullable Pair tryExtractPrimaryForeign(LogicalJoin join) { Plan primary; Plan foreign; @@ -215,6 +216,33 @@ private LogicalAggregate eliminatePrimaryOutput(LogicalAggregate agg, return Pair.of(primary, foreign); } + /** + * This class flattens nested join clusters and optimizes aggregation pushdown. + * + * Example of flattening: + * Join1 Join1 Join2 + * / \ / \ / \ + * a Join2 =====> a b b c + * / \ + * b c + * + * After flattening, we attempt to push down aggregations for each join. + * For instance, if b is a primary key table and c is a foreign key table: + * + * Original (can't push down): After flattening (can push down): + * agg(Join1) Join1 Join2 + * / \ / \ / \ + * a Join2 =====> a b b agg(c) + * / \ + * b c + * + * Finally, we can reorganize the join tree: + * Join2 + * / \ + * agg(c) Join1 + * / \ + * a b + */ static class InnerJoinCluster { private final Map> innerJoins = new HashMap<>(); private final List leaf = new ArrayList<>(); From 7ddf8dd9bf935ef00a986de561cd364bb5298555 Mon Sep 17 00:00:00 2001 From: xiejiann Date: Thu, 27 Jun 2024 14:57:27 +0800 Subject: [PATCH 8/9] remove inner join check --- .../nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index 08d53acaf6b089..35b8a593e38537 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -68,8 +68,7 @@ public List buildRules() { return ImmutableList.of( logicalAggregate( innerLogicalJoin() - .when(j -> j.getJoinType().isInnerJoin() - && !j.isMarkJoin() + .when(j -> !j.isMarkJoin() && j.getOtherJoinConjuncts().isEmpty())) .when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance)) .thenApply(ctx -> pushAgg(ctx.root, ctx.root.child())) From 08c0b69e4584c271b25371ff89dc2bebbd6ed8eb Mon Sep 17 00:00:00 2001 From: xiejiann Date: Thu, 27 Jun 2024 15:10:07 +0800 Subject: [PATCH 9/9] skip primary slot when constructing agg output --- .../rewrite/PushDownAggThroughJoinOnPkFk.java | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java index 35b8a593e38537..0fb3ed11562945 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAggThroughJoinOnPkFk.java @@ -137,7 +137,8 @@ private LogicalAggregate eliminatePrimaryOutput(LogicalAggregate agg, } Set newGroupBySlots = constructNewGroupBy(agg, primaryOutputSet, primaryToForeignDeps); - List newOutput = constructNewOutput(agg, primaryOutputSet, primaryToForeignDeps, funcDeps); + List newOutput = constructNewOutput( + agg, primaryOutputSet, primaryToForeignDeps, funcDeps, primary); if (newGroupBySlots == null || newOutput == null) { return null; } @@ -162,7 +163,7 @@ private LogicalAggregate eliminatePrimaryOutput(LogicalAggregate agg, } private @Nullable List constructNewOutput(LogicalAggregate agg, Set primaryOutput, - Map primaryToForeignDeps, FuncDeps funcDeps) { + Map primaryToForeignDeps, FuncDeps funcDeps, Plan primaryPlan) { List newOutput = new ArrayList<>(); for (NamedExpression expression : agg.getOutputExpressions()) { // There are three cases for output expressions: @@ -172,9 +173,12 @@ private LogicalAggregate eliminatePrimaryOutput(LogicalAggregate agg, // 2. Count: the count is from primary plan, // we need to replace the slot in the count with the corresponding slot // from foreign plan - // 3. Others: the expression is not from primary plan, we need to keep it - if (expression instanceof Slot && primaryToForeignDeps.containsKey(expression)) { - expression = primaryToForeignDeps.getOrDefault(expression, expression.toSlot()); + if (expression instanceof Slot && primaryPlan.getOutput().contains(expression)) { + if (primaryToForeignDeps.containsKey(expression)) { + expression = primaryToForeignDeps.getOrDefault(expression, expression.toSlot()); + } else { + continue; + } } if (expression instanceof Alias && expression.child(0) instanceof Count