Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.apache.doris.catalog.Database;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.Table;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundRelation;
import org.apache.doris.nereids.jobs.Job;
Expand All @@ -34,6 +35,7 @@
import org.apache.doris.nereids.jobs.scheduler.JobStack;
import org.apache.doris.nereids.jobs.scheduler.ScheduleContext;
import org.apache.doris.nereids.jobs.scheduler.SimpleJobScheduler;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.processor.post.RuntimeFilterContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
Expand All @@ -44,6 +46,7 @@
import org.apache.doris.nereids.rules.analysis.BindRelation.CustomTableResolver;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTE;
Expand All @@ -54,6 +57,8 @@
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
Expand All @@ -63,6 +68,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
Expand Down Expand Up @@ -103,6 +109,9 @@ public class CascadesContext implements ScheduleContext {
private Map<Integer, Set<Expression>> consumerIdToFilters = new HashMap<>();
private Map<CTEId, Set<Integer>> cteIdToConsumerUnderProjects = new HashMap<>();

// Used to update consumer's stats
private Map<CTEId, List<Pair<Map<Slot, Slot>, Group>>> cteIdToConsumerGroup = new HashMap<>();

public CascadesContext(Plan plan, Memo memo, StatementContext statementContext,
PhysicalProperties requestProperties) {
this(plan, memo, statementContext, new CTEContext(), requestProperties);
Expand Down Expand Up @@ -565,4 +574,25 @@ public boolean couldPruneColumnOnProducer(CTEId cteId) {
Set<Integer> consumerIds = this.cteIdToConsumerUnderProjects.get(cteId);
return consumerIds.size() == this.cteIdToConsumers.get(cteId).size();
}

/**
 * Registers a CTE consumer's group, paired with its slot mapping, under the given CTE id.
 *
 * <p>The per-CTE list is created on first use; every call appends one more entry.
 *
 * @param cteId id of the CTE the consumer reads from
 * @param g group to associate with the CTE id
 * @param producerSlotToConsumerSlot mapping from producer slots to this consumer's slots
 */
public void addCTEConsumerGroup(CTEId cteId, Group g, Map<Slot, Slot> producerSlotToConsumerSlot) {
    Pair<Map<Slot, Slot>, Group> registration = Pair.of(producerSlotToConsumerSlot, g);
    this.cteIdToConsumerGroup
            .computeIfAbsent(cteId, ignored -> new ArrayList<>())
            .add(registration);
}

/**
 * Pushes the producer's statistics to every consumer group registered for the given CTE.
 *
 * <p>For each registered consumer, a new {@code Statistics} is constructed from
 * {@code statistics}. Every producer column statistic is then also added under the
 * consumer slot it maps to, and the result is set on the consumer's group.
 *
 * <p>If no consumer group has been registered for {@code cteId}, the call is a no-op.
 *
 * @param cteId id of the CTE whose producer statistics are being propagated
 * @param statistics the producer's statistics
 */
public void updateConsumerStats(CTEId cteId, Statistics statistics) {
    List<Pair<Map<Slot, Slot>, Group>> consumerGroups = this.cteIdToConsumerGroup.get(cteId);
    if (consumerGroups == null) {
        // Map.get returns null when nothing was registered for this CTE id;
        // iterating it would throw a NullPointerException.
        return;
    }
    for (Pair<Map<Slot, Slot>, Group> p : consumerGroups) {
        Map<Slot, Slot> producerSlotToConsumerSlot = p.first;
        Statistics updatedConsumerStats = new Statistics(statistics);
        for (Entry<Expression, ColumnStatistic> entry : statistics.columnStatistics().entrySet()) {
            Slot consumerSlot = producerSlotToConsumerSlot.get(entry.getKey());
            if (consumerSlot == null) {
                // This producer expression has no counterpart in the consumer's slot map;
                // skip it rather than recording a statistic under a null key.
                continue;
            }
            updatedConsumerStats.addColumnStats(consumerSlot, entry.getValue());
        }
        p.value().setStatistics(updatedConsumerStats);
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ private boolean calculateEnforce(List<PhysicalProperties> requestChildrenPropert
StatsCalculator statsCalculator = StatsCalculator.estimate(groupExpression,
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump());
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
context.getCascadesContext());
if (!context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump()
&& context.getCascadesContext().getConnectContext().getSessionVariable().isEnableMinidump()) {
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ public void execute() {
context.getCascadesContext().getConnectContext().getSessionVariable().getForbidUnknownColStats(),
context.getCascadesContext().getConnectContext().getTotalColumnStatisticMap(),
context.getCascadesContext().getConnectContext().getSessionVariable().isPlayNereidsDump(),
cteIdToStats);
cteIdToStats,
context.getCascadesContext());
STATS_STATE_TRACER.log(StatsStateEvent.of(groupExpression,
groupExpression.getOwnerGroup().getStatistics()));
if (ConnectContext.get().getSessionVariable().isEnableMinidump()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Config;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
Expand Down Expand Up @@ -158,14 +159,17 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {

private Map<CTEId, Statistics> cteIdToStats;

private CascadesContext cascadesContext;

private StatsCalculator(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<CTEId, Statistics> cteIdToStats) {
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
this.groupExpression = groupExpression;
this.forbidUnknownColStats = forbidUnknownColStats;
this.totalColumnStatisticMap = columnStatisticMap;
this.isPlayNereidsDump = isPlayNereidsDump;
this.cteIdToStats = Objects.requireNonNull(cteIdToStats, "CTEIdToStats can't be null");
this.cascadesContext = context;
}

public Map<String, Histogram> getTotalHistogramMap() {
Expand All @@ -189,25 +193,26 @@ public void setTotalColumnStatisticMap(Map<String, ColumnStatistic> totalColumnS
*/
public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump,
Map<CTEId, Statistics> cteIdToStats) {
Map<CTEId, Statistics> cteIdToStats, CascadesContext context) {
StatsCalculator statsCalculator = new StatsCalculator(
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats);
groupExpression, forbidUnknownColStats, columnStatisticMap, isPlayNereidsDump, cteIdToStats, context);
statsCalculator.estimate();
return statsCalculator;
}

public static StatsCalculator estimate(GroupExpression groupExpression, boolean forbidUnknownColStats,
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump) {
Map<String, ColumnStatistic> columnStatisticMap, boolean isPlayNereidsDump, CascadesContext context) {
return StatsCalculator.estimate(groupExpression,
forbidUnknownColStats,
columnStatisticMap,
isPlayNereidsDump,
new HashMap<>());
new HashMap<>(), context);
}

public static void estimate(GroupExpression groupExpression) {
// For unit test only
public static void estimate(GroupExpression groupExpression, CascadesContext context) {
StatsCalculator statsCalculator = new StatsCalculator(groupExpression, false,
new HashMap<>(), false, Collections.EMPTY_MAP);
new HashMap<>(), false, Collections.EMPTY_MAP, context);
statsCalculator.estimate();
}

Expand Down Expand Up @@ -997,6 +1002,8 @@ public Statistics visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cte
@Override
public Statistics visitLogicalCTEConsumer(LogicalCTEConsumer cteConsumer, Void context) {
CTEId cteId = cteConsumer.getCteId();
cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(),
cteConsumer.getProducerToConsumerOutputMap());
Statistics prodStats = cteIdToStats.get(cteId);
Preconditions.checkArgument(prodStats != null, String.format("Stats for CTE: %s not found", cteId));
Statistics consumerStats = new Statistics(prodStats.getRowCount(), new HashMap<>());
Expand All @@ -1021,11 +1028,14 @@ public Statistics visitPhysicalCTEProducer(PhysicalCTEProducer<? extends Plan> c
Void context) {
Statistics statistics = groupExpression.childStatistics(0);
cteIdToStats.put(cteProducer.getCteId(), statistics);
cascadesContext.updateConsumerStats(cteProducer.getCteId(), statistics);
return statistics;
}

@Override
public Statistics visitPhysicalCTEConsumer(PhysicalCTEConsumer cteConsumer, Void context) {
cascadesContext.addCTEConsumerGroup(cteConsumer.getCteId(), groupExpression.getOwnerGroup(),
cteConsumer.getProducerToConsumerSlotMap());
CTEId cteId = cteConsumer.getCteId();
Statistics prodStats = cteIdToStats.get(cteId);
if (prodStats == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,14 @@ public void testFilter() {
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Assertions.assertEquals((10000 * 0.1 * 0.05), ownerGroup.getStatistics().getRowCount(), 0.001);

LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
Group ownerGroupOr = newGroup();
groupExpressionOr.setOwnerGroup(ownerGroupOr);
StatsCalculator.estimate(groupExpressionOr);
StatsCalculator.estimate(groupExpressionOr, null);
Assertions.assertEquals((long) (10000 * (0.1 + 0.05 - 0.1 * 0.05)),
ownerGroupOr.getStatistics().getRowCount(), 0.001);
}
Expand Down Expand Up @@ -197,14 +197,14 @@ public void testFilterOutofRange() {
GroupExpression groupExpression = new GroupExpression(logicalFilter, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Assertions.assertEquals(0, ownerGroup.getStatistics().getRowCount(), 0.001);

LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or, groupPlan);
GroupExpression groupExpressionOr = new GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
Group ownerGroupOr = newGroup();
groupExpressionOr.setOwnerGroup(ownerGroupOr);
StatsCalculator.estimate(groupExpressionOr);
StatsCalculator.estimate(groupExpressionOr, null);
Assertions.assertEquals(0, ownerGroupOr.getStatistics().getRowCount(), 0.001);
}
// TODO: temporary disable this test, until we could get column stats
Expand Down Expand Up @@ -258,7 +258,7 @@ public void testOlapScan(@Mocked ConnectContext context) {
GroupExpression groupExpression = new GroupExpression(logicalOlapScan1, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Statistics stats = ownerGroup.getStatistics();
Assertions.assertEquals(1, stats.columnStatistics().size());
Assertions.assertNotNull(stats.columnStatistics().get(slot1));
Expand Down Expand Up @@ -288,7 +288,7 @@ public void testLimit() {
GroupExpression groupExpression = new GroupExpression(logicalLimit, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Statistics limitStats = ownerGroup.getStatistics();
Assertions.assertEquals(1, limitStats.getRowCount());
ColumnStatistic slot1Stats = limitStats.columnStatistics().get(slot1);
Expand Down Expand Up @@ -318,7 +318,7 @@ public void testTopN() {
GroupExpression groupExpression = new GroupExpression(logicalTopN, ImmutableList.of(childGroup));
Group ownerGroup = newGroup();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
StatsCalculator.estimate(groupExpression, null);
Statistics topNStats = ownerGroup.getStatistics();
Assertions.assertEquals(1, topNStats.getRowCount());
ColumnStatistic slot1Stats = topNStats.columnStatistics().get(slot1);
Expand Down