diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index 23711618d6d94a..347f6e7915bf13 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -67,6 +67,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; import org.apache.hadoop.util.Lists; import java.util.ArrayList; @@ -600,8 +601,8 @@ public boolean couldPruneColumnOnProducer(CTEId cteId) { return consumerIds.size() == this.statementContext.getCteIdToConsumers().get(cteId).size(); } - public void addCTEConsumerGroup(CTEId cteId, Group g, Map producerSlotToConsumerSlot) { - List, Group>> consumerGroups = + public void addCTEConsumerGroup(CTEId cteId, Group g, Multimap producerSlotToConsumerSlot) { + List, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().computeIfAbsent(cteId, k -> new ArrayList<>()); consumerGroups.add(Pair.of(producerSlotToConsumerSlot, g)); } @@ -610,12 +611,18 @@ public void addCTEConsumerGroup(CTEId cteId, Group g, Map producerSl * Update CTE consumer group as producer's stats update */ public void updateConsumerStats(CTEId cteId, Statistics statistics) { - List, Group>> consumerGroups = this.statementContext.getCteIdToConsumerGroup().get(cteId); - for (Pair, Group> p : consumerGroups) { - Map producerSlotToConsumerSlot = p.first; + List, Group>> consumerGroups + = this.statementContext.getCteIdToConsumerGroup().get(cteId); + for (Pair, Group> p : consumerGroups) { + Multimap producerSlotToConsumerSlot = p.first; Statistics updatedConsumerStats = new Statistics(statistics); for (Entry entry : statistics.columnStatistics().entrySet()) { - updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue()); + if (!(entry.getKey() instanceof Slot)) { + continue; + } + for (Slot consumer : producerSlotToConsumerSlot.get((Slot) entry.getKey())) { + updatedConsumerStats.addColumnStats(consumer, entry.getValue()); + } } p.value().setStatistics(updatedConsumerStats); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java index 357f0286b8d155..c6bc2123082a9b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/StatementContext.java @@ -40,6 +40,7 @@ import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; import com.google.common.collect.Sets; import java.util.ArrayList; @@ -92,7 +93,7 @@ public class StatementContext { private final Map> consumerIdToFilters = new HashMap<>(); private final Map> cteIdToConsumerUnderProjects = new HashMap<>(); // Used to update consumer's stats - private final Map, Group>>> cteIdToConsumerGroup = new HashMap<>(); + private final Map, Group>>> cteIdToConsumerGroup = new HashMap<>(); private final Map rewrittenCteProducer = new HashMap<>(); private final Map rewrittenCteConsumer = new HashMap<>(); @@ -229,7 +230,7 @@ public Map> getCteIdToConsumerUnderProjects() { return cteIdToConsumerUnderProjects; } - public Map, Group>>> getCteIdToConsumerGroup() { + public Map, Group>>> getCteIdToConsumerGroup() { return cteIdToConsumerGroup; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 2d0822fbcb9657..64bcc550183930 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -987,9 +987,10 @@ public PlanFragment visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, // update expr to slot mapping for (Slot producerSlot : cteProducer.getOutput()) { - Slot consumerSlot = cteConsumer.getProducerToConsumerSlotMap().get(producerSlot); - SlotRef slotRef = context.findSlotRef(producerSlot.getExprId()); - context.addExprIdSlotRefPair(consumerSlot.getExprId(), slotRef); + for (Slot consumerSlot : cteConsumer.getProducerToConsumerSlotMap().get(producerSlot)) { + SlotRef slotRef = context.findSlotRef(producerSlot.getExprId()); + context.addExprIdSlotRefPair(consumerSlot.getExprId(), slotRef); + } } return multiCastFragment; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java index b3c0d77b1e79be..c4a27b9748d7c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AdjustNullable.java @@ -49,8 +49,10 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.LinkedHashMultimap; import com.google.common.collect.Lists; import com.google.common.collect.Maps; +import com.google.common.collect.Multimap; import java.util.LinkedHashMap; import java.util.List; @@ -238,14 +240,15 @@ public Plan visitLogicalPartitionTopN(LogicalPartitionTopN parti @Override public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Map replaceMap) { Map consumerToProducerOutputMap = new LinkedHashMap<>(); - Map producerToConsumerOutputMap = new LinkedHashMap<>(); + Multimap producerToConsumerOutputMap = LinkedHashMultimap.create(); for (Slot producerOutputSlot : cteConsumer.getConsumerToProducerOutputMap().values()) { Slot newProducerOutputSlot = updateExpression(producerOutputSlot, replaceMap); - Slot newConsumerOutputSlot = cteConsumer.getProducerToConsumerOutputMap().get(producerOutputSlot) - .withNullable(newProducerOutputSlot.nullable()); - producerToConsumerOutputMap.put(newProducerOutputSlot, newConsumerOutputSlot); - consumerToProducerOutputMap.put(newConsumerOutputSlot, newProducerOutputSlot); - replaceMap.put(newConsumerOutputSlot.getExprId(), newConsumerOutputSlot); + for (Slot consumerOutputSlot : cteConsumer.getProducerToConsumerOutputMap().get(producerOutputSlot)) { + Slot newConsumerOutputSlot = consumerOutputSlot.withNullable(newProducerOutputSlot.nullable()); + producerToConsumerOutputMap.put(newProducerOutputSlot, newConsumerOutputSlot); + consumerToProducerOutputMap.put(newConsumerOutputSlot, newProducerOutputSlot); + replaceMap.put(newConsumerOutputSlot.getExprId(), newConsumerOutputSlot); + } } return cteConsumer.withTwoMaps(consumerToProducerOutputMap, producerToConsumerOutputMap); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java index 187ec24c4f9642..36a33a6dafa47c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/OrExpansion.java @@ -166,8 +166,13 @@ private Plan expandLeftAntiJoin(CascadesContext ctx, ctx.putCTEIdToConsumer(left); ctx.putCTEIdToConsumer(right); - Map replaced = new HashMap<>(left.getProducerToConsumerOutputMap()); - replaced.putAll(right.getProducerToConsumerOutputMap()); + Map replaced = new HashMap<>(); + for (Map.Entry entry : left.getConsumerToProducerOutputMap().entrySet()) { + replaced.put(entry.getValue(), entry.getKey()); + } + for (Map.Entry entry : right.getConsumerToProducerOutputMap().entrySet()) { + replaced.put(entry.getValue(), entry.getKey()); + } List disjunctions = hashOtherConditions.first; List otherConditions = hashOtherConditions.second; List newOtherConditions = otherConditions.stream() @@ -189,8 +194,13 @@ private Plan expandLeftAntiJoin(CascadesContext ctx, LogicalCTEConsumer newRight = new LogicalCTEConsumer( ctx.getStatementContext().getNextRelationId(), rightProducer.getCteId(), "", rightProducer); ctx.putCTEIdToConsumer(newRight); - Map newReplaced = new HashMap<>(left.getProducerToConsumerOutputMap()); - newReplaced.putAll(newRight.getProducerToConsumerOutputMap()); + Map newReplaced = new HashMap<>(); + for (Map.Entry entry : left.getConsumerToProducerOutputMap().entrySet()) { + newReplaced.put(entry.getValue(), entry.getKey()); + } + for (Map.Entry entry : newRight.getConsumerToProducerOutputMap().entrySet()) { + newReplaced.put(entry.getValue(), entry.getKey()); + } newOtherConditions = otherConditions.stream() .map(e -> e.rewriteUp(s -> newReplaced.containsKey(s) ? newReplaced.get(s) : s)) .collect(Collectors.toList()); @@ -246,8 +256,13 @@ private List expandInnerJoin(CascadesContext ctx, Pair, ctx.putCTEIdToConsumer(right); //rewrite conjuncts to replace the old slots with CTE slots - Map replaced = new HashMap<>(left.getProducerToConsumerOutputMap()); - replaced.putAll(right.getProducerToConsumerOutputMap()); + Map replaced = new HashMap<>(); + for (Map.Entry entry : left.getConsumerToProducerOutputMap().entrySet()) { + replaced.put(entry.getValue(), entry.getKey()); + } + for (Map.Entry entry : right.getConsumerToProducerOutputMap().entrySet()) { + replaced.put(entry.getValue(), entry.getKey()); + } List hashCond = pair.first.stream() .map(e -> e.rewriteUp(s -> replaced.containsKey(s) ? replaced.get(s) : s)) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index 3ae47c3a3edaa0..9f2baf19903223 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -67,6 +67,8 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.LinkedHashMultimap; +import com.google.common.collect.Multimap; import java.util.LinkedHashMap; import java.util.List; @@ -448,7 +450,7 @@ public Plan visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, DeepCopierCo return context.getRelationReplaceMap().get(cteConsumer.getRelationId()); } Map consumerToProducerOutputMap = new LinkedHashMap<>(); - Map producerToConsumerOutputMap = new LinkedHashMap<>(); + Multimap producerToConsumerOutputMap = LinkedHashMultimap.create(); for (Slot consumerOutput : cteConsumer.getOutput()) { Slot newOutput = (Slot) ExpressionDeepCopier.INSTANCE.deepCopy(consumerOutput, context); consumerToProducerOutputMap.put(newOutput, cteConsumer.getProducerSlot(consumerOutput)); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java index c96d71e5daa0b2..49ca1e44a2df66 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCTEConsumer.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.RelationId; @@ -30,8 +31,10 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; @@ -45,20 +48,15 @@ public class LogicalCTEConsumer extends LogicalRelation { private final String name; private final CTEId cteId; private final Map consumerToProducerOutputMap; - private final Map producerToConsumerOutputMap; + private final Multimap producerToConsumerOutputMap; /** * Logical CTE consumer. */ public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name, - Map consumerToProducerOutputMap, Map producerToConsumerOutputMap) { - super(relationId, PlanType.LOGICAL_CTE_CONSUMER, Optional.empty(), Optional.empty()); - this.cteId = Objects.requireNonNull(cteId, "cteId should not null"); - this.name = Objects.requireNonNull(name, "name should not null"); - this.consumerToProducerOutputMap = Objects.requireNonNull(consumerToProducerOutputMap, - "consumerToProducerOutputMap should not null"); - this.producerToConsumerOutputMap = Objects.requireNonNull(producerToConsumerOutputMap, - "producerToConsumerOutputMap should not null"); + Map consumerToProducerOutputMap, Multimap producerToConsumerOutputMap) { + this(relationId, cteId, name, consumerToProducerOutputMap, producerToConsumerOutputMap, + Optional.empty(), Optional.empty()); } /** @@ -68,16 +66,23 @@ public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name, Logic super(relationId, PlanType.LOGICAL_CTE_CONSUMER, Optional.empty(), Optional.empty()); this.cteId = Objects.requireNonNull(cteId, "cteId should not null"); this.name = Objects.requireNonNull(name, "name should not null"); - this.consumerToProducerOutputMap = new LinkedHashMap<>(); - this.producerToConsumerOutputMap = new LinkedHashMap<>(); - initOutputMaps(producerPlan); + ImmutableMap.Builder cToPBuilder = ImmutableMap.builder(); + ImmutableMultimap.Builder pToCBuilder = ImmutableMultimap.builder(); + List producerOutput = producerPlan.getOutput(); + for (Slot producerOutputSlot : producerOutput) { + Slot consumerSlot = generateConsumerSlot(this.name, producerOutputSlot); + cToPBuilder.put(consumerSlot, producerOutputSlot); + pToCBuilder.put(producerOutputSlot, consumerSlot); + } + consumerToProducerOutputMap = cToPBuilder.build(); + producerToConsumerOutputMap = pToCBuilder.build(); } /** * Logical CTE consumer. */ public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name, - Map consumerToProducerOutputMap, Map producerToConsumerOutputMap, + Map consumerToProducerOutputMap, Multimap producerToConsumerOutputMap, Optional groupExpression, Optional logicalProperties) { super(relationId, PlanType.LOGICAL_CTE_CONSUMER, groupExpression, logicalProperties); this.cteId = Objects.requireNonNull(cteId, "cteId should not null"); @@ -88,21 +93,24 @@ public LogicalCTEConsumer(RelationId relationId, CTEId cteId, String name, "producerToConsumerOutputMap should not null"); } - private void initOutputMaps(LogicalPlan childPlan) { - List producerOutput = childPlan.getOutput(); - for (Slot producerOutputSlot : producerOutput) { - Slot consumerSlot = new SlotReference(producerOutputSlot.getName(), - producerOutputSlot.getDataType(), producerOutputSlot.nullable(), ImmutableList.of(name)); - producerToConsumerOutputMap.put(producerOutputSlot, consumerSlot); - consumerToProducerOutputMap.put(consumerSlot, producerOutputSlot); - } + /** + * generate a consumer slot mapping from producer slot. + */ + public static SlotReference generateConsumerSlot(String cteName, Slot producerOutputSlot) { + SlotReference slotRef = + producerOutputSlot instanceof SlotReference ? (SlotReference) producerOutputSlot : null; + return new SlotReference(StatementScopeIdGenerator.newExprId(), + producerOutputSlot.getName(), producerOutputSlot.getDataType(), + producerOutputSlot.nullable(), ImmutableList.of(cteName), + slotRef != null ? (slotRef.getColumn().isPresent() ? slotRef.getColumn().get() : null) : null, + slotRef != null ? Optional.of(slotRef.getInternalName()) : Optional.empty()); } public Map getConsumerToProducerOutputMap() { return consumerToProducerOutputMap; } - public Map getProducerToConsumerOutputMap() { + public Multimap getProducerToConsumerOutputMap() { return producerToConsumerOutputMap; } @@ -111,7 +119,8 @@ public R accept(PlanVisitor visitor, C context) { return visitor.visitLogicalCTEConsumer(this, context); } - public Plan withTwoMaps(Map consumerToProducerOutputMap, Map producerToConsumerOutputMap) { + public Plan withTwoMaps(Map consumerToProducerOutputMap, + Multimap producerToConsumerOutputMap) { return new LogicalCTEConsumer(relationId, cteId, name, consumerToProducerOutputMap, producerToConsumerOutputMap, Optional.empty(), Optional.empty()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java index 260b93e89f7370..2e902f648777b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalCTEConsumer.java @@ -31,6 +31,8 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableMultimap; +import com.google.common.collect.Multimap; import java.util.List; import java.util.Map; @@ -43,14 +45,14 @@ public class PhysicalCTEConsumer extends PhysicalRelation { private final CTEId cteId; - private final Map producerToConsumerSlotMap; private final Map consumerToProducerSlotMap; + private final Multimap producerToConsumerSlotMap; /** * Constructor */ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, Map consumerToProducerSlotMap, - Map producerToConsumerSlotMap, LogicalProperties logicalProperties) { + Multimap producerToConsumerSlotMap, LogicalProperties logicalProperties) { this(relationId, cteId, consumerToProducerSlotMap, producerToConsumerSlotMap, Optional.empty(), logicalProperties); } @@ -59,7 +61,7 @@ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, Map c * Constructor */ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, - Map consumerToProducerSlotMap, Map producerToConsumerSlotMap, + Map consumerToProducerSlotMap, Multimap producerToConsumerSlotMap, Optional groupExpression, LogicalProperties logicalProperties) { this(relationId, cteId, consumerToProducerSlotMap, producerToConsumerSlotMap, groupExpression, logicalProperties, null, null); @@ -69,14 +71,14 @@ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, * Constructor */ public PhysicalCTEConsumer(RelationId relationId, CTEId cteId, Map consumerToProducerSlotMap, - Map producerToConsumerSlotMap, Optional groupExpression, + Multimap producerToConsumerSlotMap, Optional groupExpression, LogicalProperties logicalProperties, PhysicalProperties physicalProperties, Statistics statistics) { super(relationId, PlanType.PHYSICAL_CTE_CONSUMER, groupExpression, logicalProperties, physicalProperties, statistics); this.cteId = cteId; this.consumerToProducerSlotMap = ImmutableMap.copyOf(Objects.requireNonNull( consumerToProducerSlotMap, "consumerToProducerSlotMap should not null")); - this.producerToConsumerSlotMap = ImmutableMap.copyOf(Objects.requireNonNull( + this.producerToConsumerSlotMap = ImmutableMultimap.copyOf(Objects.requireNonNull( producerToConsumerSlotMap, "consumerToProducerSlotMap should not null")); } @@ -84,7 +86,7 @@ public CTEId getCteId() { return cteId; } - public Map getProducerToConsumerSlotMap() { + public Multimap getProducerToConsumerSlotMap() { return producerToConsumerSlotMap; } diff --git a/regression-test/suites/nereids_p0/cte/test_cte_with_duplicate_consumer.groovy b/regression-test/suites/nereids_p0/cte/test_cte_with_duplicate_consumer.groovy new file mode 100644 index 00000000000000..1a92f62b3acfa5 --- /dev/null +++ b/regression-test/suites/nereids_p0/cte/test_cte_with_duplicate_consumer.groovy @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +suite("test_cte_with_duplicate_consumer") { + test { + sql """ + WITH cte1(col1) AS (SELECT 1), cte2(col2_1, col2_2) AS (SELECT col1, col1 FROM cte1) SELECT * FROM cte2 + """ + + result([[1, 1]]) + } +} +