diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java index dff58a97e5119b..8f1f68d15967c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/CascadesContext.java @@ -20,6 +20,7 @@ import org.apache.doris.catalog.Database; import org.apache.doris.catalog.Env; import org.apache.doris.catalog.Table; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.analyzer.Scope; import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.jobs.Job; @@ -34,6 +35,7 @@ import org.apache.doris.nereids.jobs.scheduler.JobStack; import org.apache.doris.nereids.jobs.scheduler.ScheduleContext; import org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler; +import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.Memo; import org.apache.doris.nereids.processor.post.RuntimeFilterContext; import org.apache.doris.nereids.properties.PhysicalProperties; @@ -44,6 +46,7 @@ import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver; import org.apache.doris.nereids.trees.expressions.CTEId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SubqueryExpr; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCTE; @@ -54,6 +57,8 @@ import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.SessionVariable; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.Statistics; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -63,6 +68,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Optional; import java.util.Set; import java.util.Stack; @@ -103,6 +109,9 @@ public class CascadesContext implements ScheduleContext { private Map> consumerIdToFilters = new HashMap<>(); private Map> cteIdToConsumerUnderProjects = new HashMap<>(); + // Used to update consumer's stats + private Map, Group>>> cteIdToConsumerGroup = new HashMap<>(); + public CascadesContext(Plan plan, Memo memo, StatementContext statementContext, PhysicalProperties requestProperties) { this(plan, memo, statementContext, new CTEContext(), requestProperties); @@ -565,4 +574,25 @@ public boolean couldPruneColumnOnProducer(CTEId cteId) { Set consumerIds = this.cteIdToConsumerUnderProjects.get(cteId); return consumerIds.size() == this.cteIdToConsumers.get(cteId).size(); } + + public void addCTEConsumerGroup(CTEId cteId, Group g, Map producerSlotToConsumerSlot) { + List, Group>> consumerGroups = + this.cteIdToConsumerGroup.computeIfAbsent(cteId, k -> new ArrayList<>()); + consumerGroups.add(Pair.of(producerSlotToConsumerSlot, g)); + } + + /** + * Update CTE consumer group as producer's stats update + */ + public void updateConsumerStats(CTEId cteId, Statistics statistics) { + List, Group>> consumerGroups = this.cteIdToConsumerGroup.get(cteId); + for (Pair, Group> p : consumerGroups) { + Map producerSlotToConsumerSlot = p.first; + Statistics updatedConsumerStats = new Statistics(statistics); + for (Entry entry : statistics.columnStatistics().entrySet()) { + updatedConsumerStats.addColumnStats(producerSlotToConsumerSlot.get(entry.getKey()), entry.getValue()); + } + p.value().setStatistics(updatedConsumerStats); + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java index d8e79bd80fbc1a..8ba55542945adc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/CostAndEnforcerJob.java @@ -245,7 +245,8 @@ private boolean calculateEnforce(List requestChildrenPropert StatsCalculator statsCalculator = StatsCalculator.estimate(groupExpression, context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(), context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(), - context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump()); + context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(), + context.getCascadesContext()); if (!context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump() && context.getCascadesContext().getConnectContext().getSessionVariable().isEnableMinidump()) { context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java index 6ff01f278dec29..cfe952c0f26873 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java @@ -105,7 +105,8 @@ public void execute() { context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(), context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(), context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(), - cteIdToStats); + cteIdToStats, + context.getCascadesContext()); STATS_STATE_TRACER.log(StatsStateEvent.of(groupExpression, groupExpression.getOwnerGroup().getStatistics())); if (ConnectContext.get().getSessionVariable().isEnableMinidump() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java index 385c0227ea2687..9883815b43fcc2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/StatsCalculator.java @@ -24,6 +24,7 @@ import org.apache.doris.catalog.TableIf; import org.apache.doris.common.Config; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; @@ -158,14 +159,17 @@ public class StatsCalculator extends DefaultPlanVisitor { private Map cteIdToStats; + private CascadesContext cascadesContext; + private StatsCalculator(GroupExpression groupExpression, boolean forbidUnknownColStats, Map columnStatisticMap, boolean isPlayNereidsDump, - Map cteIdToStats) { + Map cteIdToStats, CascadesContext context) { this.groupExpression = groupExpression; this.forbidUnknownColStats = forbidUnknownColStats; this.totalColumnStatisticMap = columnStatisticMap; this.isPlayNereidsDump = isPlayNereidsDump; this.cteIdToStats = Objects.requireNonNull(cteIdToStats, "CTEIdToStats can't be null"); + this.cascadesContext = context; } public Map getTotalHistogramMap() { @@ -189,25 +193,26 @@ public void setTotalColumnStatisticMap(Map totalColumnS */ public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats, Map columnStatisticMap, boolean isPlayNereidsDump, - Map cteIdToStats) { + Map cteIdToStats, CascadesContext context) { StatsCalculator statsCalculator = new StatsCalculator( - groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats); + groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats, context); statsCalculator.estimate(); return statsCalculator; } public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats, - Map columnStatisticMap, boolean isPlayNereidsDump) { + Map columnStatisticMap, boolean isPlayNereidsDump, CascadesContext context) { return StatsCalculator.estimate(groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, - new HashMap<>()); + new HashMap<>(), context); } - public static void estimate(GroupExpression groupExpression) { + // For unit test only + public static void estimate(GroupExpression groupExpression, CascadesContext context) { StatsCalculator statsCalculator = new StatsCalculator(groupExpression, false, - new HashMap<>(), false, Collections.EMPTY_MAP); + new HashMap<>(), false, Collections.EMPTY_MAP, context); statsCalculator.estimate(); } @@ -997,6 +1002,8 @@ public Statistics visitLogicalCTEProducer(LogicalCTEProducer cte @Override public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void context) { CTEId cteId = cteConsumer.getCteId(); + cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(), + cteConsumer.getProducerToConsumerOutputMap()); Statistics prodStats = cteIdToStats.get(cteId); Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId)); Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>()); @@ -1021,11 +1028,14 @@ public Statistics visitPhysicalCTEProducer(PhysicalCTEProducer c Void context) { Statistics statistics = groupExpression.childStatistics(0); cteIdToStats.put(cteProducer.getCteId(), statistics); + cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics); return statistics; } @Override public Statistics visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, Void context) { + cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(), + cteConsumer.getProducerToConsumerSlotMap()); CTEId cteId = cteConsumer.getCteId(); Statistics prodStats = cteIdToStats.get(cteId); if (prodStats == null) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java index 8b928f4215e247..2ceca6f8ec1e3d 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java @@ -144,14 +144,14 @@ public void testFilter() { GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup)); Group ownerGroup = newGroup(); groupExpression.setOwnerGroup(ownerGroup); - StatsCalculator.estimate(groupExpression); + StatsCalculator.estimate(groupExpression, null); Assertions.assertEquals((10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001); LogicalFilter logicalFilterOr = new LogicalFilter<>(or, groupPlan); GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup)); Group ownerGroupOr = newGroup(); groupExpressionOr.setOwnerGroup(ownerGroupOr); - StatsCalculator.estimate(groupExpressionOr); + StatsCalculator.estimate(groupExpressionOr, null); Assertions.assertEquals((long) (10000 * (0.1 + 0.05 - 0.1 * 0.05)), ownerGroupOr.getStatistics().getRowCount(), 0.001); } @@ -197,14 +197,14 @@ public void testFilterOutofRange() { GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup)); Group ownerGroup = newGroup(); groupExpression.setOwnerGroup(ownerGroup); - StatsCalculator.estimate(groupExpression); + StatsCalculator.estimate(groupExpression, null); Assertions.assertEquals(0, ownerGroup.getStatistics().getRowCount(), 0.001); LogicalFilter logicalFilterOr = new LogicalFilter<>(or, groupPlan); GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup)); Group ownerGroupOr = newGroup(); groupExpressionOr.setOwnerGroup(ownerGroupOr); - StatsCalculator.estimate(groupExpressionOr); + StatsCalculator.estimate(groupExpressionOr, null); Assertions.assertEquals(0, ownerGroupOr.getStatistics().getRowCount(), 0.001); } // TODO: temporary disable this test, until we could get column stats @@ -258,7 +258,7 @@ public void testOlapScan(@Mocked ConnectContext context) { GroupExpression groupExpression = new GroupExpression(logicalOlapScan1, ImmutableList.of(childGroup)); Group ownerGroup = newGroup(); groupExpression.setOwnerGroup(ownerGroup); - StatsCalculator.estimate(groupExpression); + StatsCalculator.estimate(groupExpression, null); Statistics stats = ownerGroup.getStatistics(); Assertions.assertEquals(1, stats.columnStatistics().size()); Assertions.assertNotNull(stats.columnStatistics().get(slot1)); @@ -288,7 +288,7 @@ public void testLimit() { GroupExpression groupExpression = new GroupExpression(logicalLimit, ImmutableList.of(childGroup)); Group ownerGroup = newGroup(); ownerGroup.addGroupExpression(groupExpression); - StatsCalculator.estimate(groupExpression); + StatsCalculator.estimate(groupExpression, null); Statistics limitStats = ownerGroup.getStatistics(); Assertions.assertEquals(1, limitStats.getRowCount()); ColumnStatistic slot1Stats = limitStats.columnStatistics().get(slot1); @@ -318,7 +318,7 @@ public void testTopN() { GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup)); Group ownerGroup = newGroup(); ownerGroup.addGroupExpression(groupExpression); - StatsCalculator.estimate(groupExpression); + StatsCalculator.estimate(groupExpression, null); Statistics topNStats = ownerGroup.getStatistics(); Assertions.assertEquals(1, topNStats.getRowCount()); ColumnStatistic slot1Stats = topNStats.columnStatistics().get(slot1);