diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 81903aa0ec669a..9956cbe318db0a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -2567,7 +2567,7 @@ public PlanFragment visitPhysicalRepeat(PhysicalRepeat repeat, P // cube and rollup already convert to grouping sets in LogicalPlanBuilder.withAggregate() GroupingInfo groupingInfo = new GroupingInfo(outputTuple, preRepeatExprs); - List> repeatSlotIdList = repeat.computeRepeatSlotIdList(getSlotIds(outputTuple)); + List> repeatSlotIdList = repeat.computeRepeatSlotIdList(getSlotIds(outputTuple), outputSlots); Set allSlotId = repeatSlotIdList.stream() .flatMap(Set::stream) .collect(ImmutableSet.toImmutableSet()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index af86d18d494964..78ad895d0120bd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -608,6 +608,7 @@ import org.apache.doris.nereids.trees.plans.algebra.Aggregate; import org.apache.doris.nereids.trees.plans.algebra.InlineTable; import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation; +import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType; import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; import org.apache.doris.nereids.trees.plans.commands.AddConstraintCommand; import org.apache.doris.nereids.trees.plans.commands.AdminCancelRebalanceDiskCommand; @@ -4753,15 +4754,15 @@ private LogicalPlan withAggregate(LogicalPlan input, SelectColumnClauseContext s for (GroupingSetContext groupingSetContext : groupingElementContext.groupingSet()) { groupingSets.add(visit(groupingSetContext.expression(), Expression.class)); } - return new LogicalRepeat<>(groupingSets.build(), namedExpressions, input); + return new LogicalRepeat<>(groupingSets.build(), namedExpressions, RepeatType.GROUPING_SETS, input); } else if (groupingElementContext.CUBE() != null) { List cubeExpressions = visit(groupingElementContext.expression(), Expression.class); List> groupingSets = ExpressionUtils.cubeToGroupingSets(cubeExpressions); - return new LogicalRepeat<>(groupingSets, namedExpressions, input); + return new LogicalRepeat<>(groupingSets, namedExpressions, RepeatType.CUBE, input); } else if (groupingElementContext.ROLLUP() != null && groupingElementContext.WITH() == null) { List rollupExpressions = visit(groupingElementContext.expression(), Expression.class); List> groupingSets = ExpressionUtils.rollupToGroupingSets(rollupExpressions); - return new LogicalRepeat<>(groupingSets, namedExpressions, input); + return new LogicalRepeat<>(groupingSets, namedExpressions, RepeatType.ROLLUP, input); } else { List groupKeyWithOrders = visit(groupingElementContext.expressionWithOrder(), GroupKeyWithOrder.class); @@ -4775,7 +4776,7 @@ private LogicalPlan withAggregate(LogicalPlan input, SelectColumnClauseContext s } if (groupingElementContext.ROLLUP() != null) { List> groupingSets = ExpressionUtils.rollupToGroupingSets(groupByExpressions); - return new LogicalRepeat<>(groupingSets, namedExpressions, input); + return new LogicalRepeat<>(groupingSets, namedExpressions, RepeatType.ROLLUP, input); } else { return new LogicalAggregate<>(groupByExpressions, namedExpressions, input); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java index 87ddc7a0ca5220..211f2578721852 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewAggregateRule.java @@ -264,7 +264,8 @@ protected LogicalAggregate aggregateRewriteByView( } } LogicalRepeat repeat = new LogicalRepeat<>(rewrittenGroupSetsExpressions, - finalOutputExpressions, queryStructInfo.getGroupingId().get(), tempRewritedPlan); + finalOutputExpressions, queryStructInfo.getGroupingId().get(), + queryAggregate.getSourceRepeat().get().getRepeatType(), tempRewritedPlan); return NormalizeRepeat.doNormalize(repeat); } return new LogicalAggregate<>(finalGroupExpressions, finalOutputExpressions, tempRewritedPlan); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java index ed94fa730ca450..de9526005d0983 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/SplitAggWithoutDistinct.java @@ -96,7 +96,8 @@ private List implementOnePhase(LogicalAggregate logicalAgg } ); AggregateParam param = new AggregateParam(AggPhase.GLOBAL, AggMode.INPUT_TO_RESULT, !skipRegulator(logicalAgg)); - return ImmutableList.of(new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), aggOutput, param, + return ImmutableList.of(new PhysicalHashAggregate<>(logicalAgg.getGroupByExpressions(), aggOutput, + logicalAgg.getPartitionExpressions(), param, AggregateUtils.maybeUsingStreamAgg(logicalAgg.getGroupByExpressions(), param), null, logicalAgg.child())); } @@ -159,7 +160,7 @@ public Void visitSessionVarGuardExpr(SessionVarGuardExpr expr, Map(aggregate.getGroupByExpressions(), - globalAggOutput, bufferToResultParam, + globalAggOutput, aggregate.getPartitionExpressions(), bufferToResultParam, AggregateUtils.maybeUsingStreamAgg(aggregate.getGroupByExpressions(), bufferToResultParam), aggregate.getLogicalProperties(), localAgg)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java index 6f6a7f373a5634..fd40b635ffde11 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java @@ -19,6 +19,7 @@ import org.apache.doris.nereids.jobs.JobContext; import org.apache.doris.nereids.rules.rewrite.DistinctAggStrategySelector.DistinctSelectorContext; +import org.apache.doris.nereids.rules.rewrite.StatsDerive.DeriveContext; import org.apache.doris.nereids.trees.copier.DeepCopierContext; import org.apache.doris.nereids.trees.copier.LogicalPlanDeepCopier; import org.apache.doris.nereids.trees.expressions.Alias; @@ -50,6 +51,10 @@ import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.Statistics; +import org.apache.doris.statistics.util.StatisticsUtil; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -61,6 +66,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.TreeMap; /** * This rule will rewrite grouping sets. eg: @@ -119,13 +125,13 @@ public Plan visitLogicalCTEAnchor( @Override public Plan visitLogicalAggregate(LogicalAggregate aggregate, DistinctSelectorContext ctx) { aggregate = visitChildren(this, aggregate, ctx); - int maxGroupIndex = canOptimize(aggregate); + int maxGroupIndex = canOptimize(aggregate, ctx.cascadesContext.getConnectContext()); if (maxGroupIndex < 0) { return aggregate; } Map preToProducerSlotMap = new HashMap<>(); LogicalCTEProducer> producer = constructProducer(aggregate, maxGroupIndex, ctx, - preToProducerSlotMap); + preToProducerSlotMap, ctx.cascadesContext.getConnectContext()); LogicalCTEConsumer aggregateConsumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), producer.getCteId(), "", producer); LogicalCTEConsumer directConsumer = new LogicalCTEConsumer(ctx.statementContext.getNextRelationId(), @@ -276,6 +282,8 @@ private LogicalAggregate constructAgg(LogicalAggregate agg replacedExpr.toSlot()); } } + // NOTE: shuffle key selection is applied on the pre-agg (producer) side by setting + // LogicalAggregate.partitionExpressions. See constructProducer(). return new LogicalAggregate<>(topAggGby, topAggOutput, Optional.of(newRepeat), newRepeat); } @@ -369,16 +377,14 @@ private LogicalUnion constructUnion(LogicalPlan aggregateProject, LogicalPlan di * Determine if optimization is possible; if so, return the index of the largest group. * The optimization requires: * 1. The aggregate's child must be a LogicalRepeat - * 2. All aggregate functions must be Sum, Min, or Max (non-distinct) - * 3. No GroupingScalarFunction in repeat output - * 4. More than 3 grouping sets - * 5. There exists a grouping set that contains all other grouping sets - * + * 2. All aggregate functions must be in SUPPORT_AGG_FUNCTIONS. + * 3. More than 3 grouping sets + * 4. There exists a grouping set that contains all other grouping sets * @param aggregate the aggregate plan to check * @return value -1 means can not be optimized, values other than -1 * represent the index of the set that contains all other sets */ - private int canOptimize(LogicalAggregate aggregate) { + private int canOptimize(LogicalAggregate aggregate, ConnectContext connectContext) { Plan aggChild = aggregate.child(); if (!(aggChild instanceof LogicalRepeat)) { return -1; @@ -398,7 +404,7 @@ private int canOptimize(LogicalAggregate aggregate) { // This is an empirical threshold: when there are too few grouping sets, // the overhead of creating CTE and union may outweigh the benefits. // The value 3 is chosen heuristically based on practical experience. - if (groupingSets.size() <= 3) { + if (groupingSets.size() <= connectContext.getSessionVariable().decomposeRepeatThreshold) { return -1; } return findMaxGroupingSetIndex(groupingSets); @@ -426,6 +432,9 @@ private int findMaxGroupingSetIndex(List> groupingSets) { maxGroupIndex = i; } } + if (groupingSets.get(maxGroupIndex).isEmpty()) { + return -1; + } // Second pass: verify that the max-size grouping set contains all other grouping sets ImmutableSet maxGroup = ImmutableSet.copyOf(groupingSets.get(maxGroupIndex)); for (int i = 0; i < groupingSets.size(); ++i) { @@ -450,7 +459,8 @@ private int findMaxGroupingSetIndex(List> groupingSets) { * @return a LogicalCTEProducer containing the pre-aggregation */ private LogicalCTEProducer> constructProducer(LogicalAggregate aggregate, - int maxGroupIndex, DistinctSelectorContext ctx, Map preToCloneSlotMap) { + int maxGroupIndex, DistinctSelectorContext ctx, Map preToCloneSlotMap, + ConnectContext connectContext) { LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); List maxGroupByList = repeat.getGroupingSets().get(maxGroupIndex); List originAggOutputs = aggregate.getOutputExpressions(); @@ -469,6 +479,11 @@ private LogicalCTEProducer> constructProducer(LogicalAggr } LogicalAggregate preAgg = new LogicalAggregate<>(maxGroupByList, orderedAggOutputs, repeat.child()); + Optional> partitionExprs = choosePreAggShuffleKeyPartitionExprs( + repeat, maxGroupIndex, maxGroupByList, connectContext); + if (partitionExprs.isPresent() && !partitionExprs.get().isEmpty()) { + preAgg = preAgg.withPartitionExpressions(partitionExprs); + } LogicalAggregate preAggClone = (LogicalAggregate) LogicalPlanDeepCopier.INSTANCE .deepCopy(preAgg, new DeepCopierContext()); for (int i = 0; i < preAgg.getOutputExpressions().size(); ++i) { @@ -480,6 +495,95 @@ private LogicalCTEProducer> constructProducer(LogicalAggr return producer; } + /** + * Choose partition expressions (shuffle key) for pre-aggregation (producer agg). + */ + private Optional> choosePreAggShuffleKeyPartitionExprs( + LogicalRepeat repeat, int maxGroupIndex, List maxGroupByList, + ConnectContext connectContext) { + int idx = connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup; + if (idx >= 0 && idx < maxGroupByList.size()) { + return Optional.of(ImmutableList.of(maxGroupByList.get(idx))); + } + if (repeat.child().getStats() == null) { + repeat.child().accept(new StatsDerive(false), new DeriveContext()); + } + Statistics inputStats = repeat.child().getStats(); + if (inputStats == null) { + return Optional.empty(); + } + int beNumber = Math.max(1, connectContext.getEnv().getClusterInfo().getBackendsNumber(true)); + int parallelInstance = Math.max(1, connectContext.getSessionVariable().getParallelExecInstanceNum()); + int totalInstanceNum = beNumber * parallelInstance; + Optional chosen; + switch (repeat.getRepeatType()) { + case CUBE: + // Prefer larger NDV to improve balance + chosen = chooseOneBalancedKey(maxGroupByList, inputStats, totalInstanceNum); + break; + case GROUPING_SETS: + chosen = chooseByAppearanceThenNdv(repeat.getGroupingSets(), maxGroupIndex, maxGroupByList, + inputStats, totalInstanceNum); + break; + case ROLLUP: + chosen = chooseOneBalancedKey(maxGroupByList, inputStats, totalInstanceNum); + break; + default: + chosen = Optional.empty(); + } + return chosen.map(ImmutableList::of); + } + + private Optional chooseOneBalancedKey(List candidates, Statistics inputStats, + int totalInstanceNum) { + if (inputStats == null) { + return Optional.empty(); + } + for (Expression candidate : candidates) { + ColumnStatistic columnStatistic = inputStats.findColumnStatistics(candidate); + if (columnStatistic == null || columnStatistic.isUnKnown()) { + continue; + } + if (StatisticsUtil.isBalanced(columnStatistic, inputStats.getRowCount(), totalInstanceNum)) { + return Optional.of(candidate); + } + } + return Optional.empty(); + } + + /** + * GROUPING_SETS: prefer keys appearing in more (non-max) grouping sets, tie-break by larger NDV. + */ + private Optional chooseByAppearanceThenNdv(List> groupingSets, int maxGroupIndex, + List candidates, Statistics inputStats, int totalInstanceNum) { + Map appearCount = new HashMap<>(); + for (Expression c : candidates) { + appearCount.put(c, 0); + } + for (int i = 0; i < groupingSets.size(); i++) { + if (i == maxGroupIndex) { + continue; + } + List set = groupingSets.get(i); + for (Expression c : candidates) { + if (set.contains(c)) { + appearCount.put(c, appearCount.get(c) + 1); + } + } + } + TreeMap> countToCandidate = new TreeMap<>(); + for (Map.Entry entry : appearCount.entrySet()) { + countToCandidate.computeIfAbsent(entry.getValue(), v -> new ArrayList<>()).add(entry.getKey()); + } + for (Map.Entry> entry : countToCandidate.descendingMap().entrySet()) { + Optional chosen = chooseOneBalancedKey(entry.getValue(), inputStats, totalInstanceNum); + if (chosen.isPresent()) { + return chosen; + } + } + return Optional.empty(); + } + /** * Construct a new LogicalRepeat with reduced grouping sets and replaced expressions. * The grouping sets and output expressions are replaced using the slot mapping from producer to consumer. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java index c969863493c1a0..7fe2dd26ae0dbb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java @@ -193,9 +193,18 @@ public Plan visitLogicalAggregate(LogicalAggregate aggregate, De outputExpressions, child); Optional> childRepeat = copiedAggregate.collectFirst(LogicalRepeat.class::isInstance); - return childRepeat.isPresent() ? aggregate.withChildGroupByAndOutputAndSourceRepeat( - groupByExpressions, outputExpressions, child, childRepeat) - : aggregate.withChildGroupByAndOutput(groupByExpressions, outputExpressions, child); + List partitionExpressions = ImmutableList.of(); + if (aggregate.getPartitionExpressions().isPresent()) { + partitionExpressions = aggregate.getPartitionExpressions().get().stream() + .map(k -> ExpressionDeepCopier.INSTANCE.deepCopy(k, context)) + .collect(ImmutableList.toImmutableList()); + } + Optional> optionalPartitionExpressions = partitionExpressions.isEmpty() + ? Optional.empty() : Optional.of(partitionExpressions); + return childRepeat.isPresent() ? aggregate.withChildGroupByAndOutputAndSourceRepeatAndPartitionExpr( + groupByExpressions, outputExpressions, optionalPartitionExpressions, child, childRepeat) + : aggregate.withChildGroupByAndOutputAndPartitionExpr(groupByExpressions, outputExpressions, + optionalPartitionExpressions, child); } @Override @@ -211,7 +220,7 @@ public Plan visitLogicalRepeat(LogicalRepeat repeat, DeepCopierC .collect(ImmutableList.toImmutableList()); SlotReference groupingId = (SlotReference) ExpressionDeepCopier.INSTANCE .deepCopy(repeat.getGroupingId().get(), context); - return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, child); + return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, repeat.getRepeatType(), child); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java index e35b48073b591a..7a7f1f1f3c2772 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Repeat.java @@ -18,16 +18,15 @@ package org.apache.doris.nereids.trees.plans.algebra; import org.apache.doris.nereids.exceptions.AnalysisException; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.util.BitUtils; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; @@ -117,14 +116,35 @@ default List> computeGroupingFunctionsValues() { /** * flatten the grouping sets and build to a GroupingSetShapes. + * This method ensures that all expressions referenced by grouping functions are included + * in the flattenGroupingSetExpression, even if they are not in any grouping set. + * This is necessary for optimization scenarios where some expressions may only exist + * in the maximum grouping set that was removed during optimization. */ default GroupingSetShapes toShapes() { - Set flattenGroupingSet = ImmutableSet.copyOf(ExpressionUtils.flatExpressions(getGroupingSets())); + // Collect all expressions referenced by grouping functions to ensure they are included + // in flattenGroupingSetExpression, even if they are not in any grouping set. + // This maintains semantic constraints while allowing optimization. + List groupingFunctions = ExpressionUtils.collectToList( + getOutputExpressions(), GroupingScalarFunction.class::isInstance); + Set groupingFunctionArgs = Sets.newLinkedHashSet(); + for (GroupingScalarFunction function : groupingFunctions) { + groupingFunctionArgs.addAll(function.getArguments()); + } + // Merge grouping set expressions with grouping function arguments + // Use LinkedHashSet to preserve order: grouping sets first, then grouping function args + Set flattenGroupingSet = Sets.newLinkedHashSet(getGroupByExpressions()); + for (Expression arg : groupingFunctionArgs) { + if (!flattenGroupingSet.contains(arg)) { + flattenGroupingSet.add(arg); + } + } List shapes = Lists.newArrayList(); for (List groupingSet : getGroupingSets()) { List shouldBeErasedToNull = Lists.newArrayListWithCapacity(flattenGroupingSet.size()); - for (Expression groupingSetExpression : flattenGroupingSet) { - shouldBeErasedToNull.add(!groupingSet.contains(groupingSetExpression)); + for (Expression expression : flattenGroupingSet) { + // If expression is not in the current grouping set, it should be erased to null + shouldBeErasedToNull.add(!groupingSet.contains(expression)); } shapes.add(new GroupingSetShape(shouldBeErasedToNull)); } @@ -140,8 +160,8 @@ default GroupingSetShapes toShapes() { * * return: [(4, 3), (3)] */ - default List> computeRepeatSlotIdList(List slotIdList) { - List> groupingSetsIndexesInOutput = getGroupingSetsIndexesInOutput(); + default List> computeRepeatSlotIdList(List slotIdList, List outputSlots) { + List> groupingSetsIndexesInOutput = getGroupingSetsIndexesInOutput(outputSlots); List> repeatSlotIdList = Lists.newArrayList(); for (Set groupingSetIndex : groupingSetsIndexesInOutput) { // keep order @@ -160,8 +180,8 @@ default List> computeRepeatSlotIdList(List slotIdList) { * e.g. groupingSets=((b, a), (a)), output=[a, b] * return ((1, 0), (1)) */ - default List> getGroupingSetsIndexesInOutput() { - Map indexMap = indexesOfOutput(); + default List> getGroupingSetsIndexesInOutput(List outputSlots) { + Map indexMap = indexesOfOutput(outputSlots); List> groupingSetsIndex = Lists.newArrayList(); List> groupingSets = getGroupingSets(); @@ -184,23 +204,22 @@ default List> getGroupingSetsIndexesInOutput() { /** * indexesOfOutput: get the indexes which mapping from the expression to the index in the output. * - * e.g. output=[a + 1, b + 2, c] + * e.g. outputSlots=[a + 1, b + 2, c] * * return the map( * `a + 1`: 0, * `b + 2`: 1, * `c`: 2 * ) + * + * Use outputSlots in physicalPlanTranslator instead of getOutputExpressions() in this method, + * because the outputSlots have same order with slotIdList. */ - default Map indexesOfOutput() { + static Map indexesOfOutput(List outputSlots) { Map indexes = Maps.newLinkedHashMap(); - List outputs = getOutputExpressions(); - for (int i = 0; i < outputs.size(); i++) { - NamedExpression output = outputs.get(i); + for (int i = 0; i < outputSlots.size(); i++) { + NamedExpression output = outputSlots.get(i); indexes.put(output, i); - if (output instanceof Alias) { - indexes.put(((Alias) output).child(), i); - } } return indexes; } @@ -302,4 +321,11 @@ public String toString() { return "GroupingSetShape(shouldBeErasedToNull=" + shouldBeErasedToNull + ")"; } } + + /** RepeatType */ + enum RepeatType { + ROLLUP, + CUBE, + GROUPING_SETS + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index d063ccb40aa3ed..b7b4e4f756b456 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -79,6 +79,7 @@ public class LogicalAggregate private final boolean generated; private final boolean hasPushed; private final boolean withInProjection; + private final Optional> partitionExpressions; /** * Desc: Constructor for LogicalAggregate. @@ -97,19 +98,19 @@ public LogicalAggregate( public LogicalAggregate(List namedExpressions, boolean generated, CHILD_TYPE child) { this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, true, generated, false, true, Optional.empty(), - Optional.empty(), Optional.empty(), child); + Optional.empty(), Optional.empty(), Optional.empty(), child); } public LogicalAggregate(List namedExpressions, boolean generated, boolean hasPushed, CHILD_TYPE child) { this(ImmutableList.copyOf(namedExpressions), namedExpressions, false, true, generated, hasPushed, true, - Optional.empty(), Optional.empty(), Optional.empty(), child); + Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty(), child); } public LogicalAggregate(List groupByExpressions, List outputExpressions, boolean ordinalIsResolved, CHILD_TYPE child) { this(groupByExpressions, outputExpressions, false, ordinalIsResolved, false, false, true, Optional.empty(), - Optional.empty(), Optional.empty(), child); + Optional.empty(), Optional.empty(), Optional.empty(), child); } /** @@ -131,7 +132,7 @@ public LogicalAggregate( Optional> sourceRepeat, CHILD_TYPE child) { this(groupByExpressions, outputExpressions, normalized, false, false, false, true, sourceRepeat, - Optional.empty(), Optional.empty(), child); + Optional.empty(), Optional.empty(), Optional.empty(), child); } /** @@ -148,6 +149,7 @@ private LogicalAggregate( Optional> sourceRepeat, Optional groupExpression, Optional logicalProperties, + Optional> partitionExpressions, CHILD_TYPE child) { super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties, child); this.groupByExpressions = ImmutableList.copyOf(groupByExpressions); @@ -162,6 +164,7 @@ private LogicalAggregate( this.hasPushed = hasPushed; this.sourceRepeat = Objects.requireNonNull(sourceRepeat, "sourceRepeat cannot be null"); this.withInProjection = withInProjection; + this.partitionExpressions = partitionExpressions; } @Override @@ -280,6 +283,16 @@ public List getExpressions() { .build(); } + public Optional> getPartitionExpressions() { + return partitionExpressions; + } + + public LogicalAggregate withPartitionExpressions(Optional> newPartitionExpressions) { + return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newPartitionExpressions, + child()); + } + public boolean isNormalized() { return normalized; } @@ -304,26 +317,29 @@ public boolean equals(Object o) { && normalized == that.normalized && ordinalIsResolved == that.ordinalIsResolved && generated == that.generated - && Objects.equals(sourceRepeat, that.sourceRepeat); + && Objects.equals(sourceRepeat, that.sourceRepeat) + && Objects.equals(partitionExpressions, that.partitionExpressions); } @Override public int hashCode() { - return Objects.hash(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, sourceRepeat); + return Objects.hash(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, sourceRepeat, + partitionExpressions); } @Override public LogicalAggregate withChildren(List children) { Preconditions.checkArgument(children.size() == 1); return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), children.get(0)); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + children.get(0)); } @Override public LogicalAggregate withGroupExpression(Optional groupExpression) { return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, - sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0)); + hasPushed, withInProjection, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), + partitionExpressions, children.get(0)); } @Override @@ -331,39 +347,52 @@ public Plan withGroupExprLogicalPropChildren(Optional groupExpr Optional logicalProperties, List children) { Preconditions.checkArgument(children.size() == 1); return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, - sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0)); + hasPushed, withInProjection, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), + partitionExpressions, children.get(0)); } public LogicalAggregate withGroupByAndOutput(List groupByExprList, List outputExpressionList) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), child()); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + child()); } public LogicalAggregate withGroupBy(List groupByExprList) { return new LogicalAggregate<>(groupByExprList, outputExpressions, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), child()); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + child()); } public LogicalAggregate withChildGroupByAndOutput(List groupByExprList, List outputExpressionList, Plan newChild) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newChild); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + newChild); + } + + public LogicalAggregate withChildGroupByAndOutputAndPartitionExpr(List groupByExprList, + List outputExpressionList, Optional> partitionExpressions, + Plan newChild) { + return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), + partitionExpressions, newChild); } - public LogicalAggregate withChildGroupByAndOutputAndSourceRepeat(List groupByExprList, - List outputExpressionList, Plan newChild, - Optional> sourceRepeat) { + public LogicalAggregate withChildGroupByAndOutputAndSourceRepeatAndPartitionExpr( + List groupByExprList, + List outputExpressionList, Optional> partitionExpressions, Plan newChild, + Optional> sourceRepeat) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newChild); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), + partitionExpressions, newChild); } public LogicalAggregate withChildAndOutput(CHILD_TYPE child, List outputExpressionList) { return new LogicalAggregate<>(groupByExpressions, outputExpressionList, normalized, ordinalIsResolved, generated, hasPushed, withInProjection, sourceRepeat, Optional.empty(), - Optional.empty(), child); + Optional.empty(), partitionExpressions, child); } @Override @@ -374,30 +403,33 @@ public List getOutputs() { @Override public LogicalAggregate withAggOutput(List newOutput) { return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), child()); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + child()); } public LogicalAggregate withAggOutputChild(List newOutput, Plan newChild) { return new LogicalAggregate<>(groupByExpressions, newOutput, normalized, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), newChild); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + newChild); } public LogicalAggregate withNormalized(List normalizedGroupBy, List normalizedOutput, Plan normalizedChild) { return new LogicalAggregate<>(normalizedGroupBy, normalizedOutput, true, ordinalIsResolved, generated, - hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), normalizedChild); + hasPushed, withInProjection, sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, + normalizedChild); } public LogicalAggregate withInProjection(boolean withInProjection) { return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, hasPushed, withInProjection, - sourceRepeat, Optional.empty(), Optional.empty(), child()); + sourceRepeat, Optional.empty(), Optional.empty(), partitionExpressions, child()); } public LogicalAggregate withSourceRepeat(LogicalRepeat sourceRepeat) { return new LogicalAggregate<>(groupByExpressions, outputExpressions, normalized, ordinalIsResolved, generated, hasPushed, withInProjection, Optional.ofNullable(sourceRepeat), - Optional.empty(), Optional.empty(), child()); + Optional.empty(), Optional.empty(), partitionExpressions, child()); } private boolean isUniqueGroupByUnique(NamedExpression namedExpression) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java index a7ab86de4fbd44..7b10cd2ced4ca2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalRepeat.java @@ -57,6 +57,7 @@ public class LogicalRepeat extends LogicalUnary outputExpressions; private final Optional groupingId; private final boolean withInProjection; + private final RepeatType type; /** * Desc: Constructor for LogicalRepeat. @@ -64,8 +65,9 @@ public class LogicalRepeat extends LogicalUnary> groupingSets, List outputExpressions, + RepeatType type, CHILD_TYPE child) { - this(groupingSets, outputExpressions, Optional.empty(), child); + this(groupingSets, outputExpressions, Optional.empty(), type, child); } /** @@ -75,9 +77,10 @@ public LogicalRepeat( List> groupingSets, List outputExpressions, SlotReference groupingId, + RepeatType type, CHILD_TYPE child) { this(groupingSets, outputExpressions, Optional.empty(), Optional.empty(), - Optional.ofNullable(groupingId), true, child); + Optional.ofNullable(groupingId), true, type, child); } /** @@ -87,8 +90,9 @@ private LogicalRepeat( List> groupingSets, List outputExpressions, Optional groupingId, + RepeatType type, CHILD_TYPE child) { - this(groupingSets, outputExpressions, Optional.empty(), Optional.empty(), groupingId, true, child); + this(groupingSets, outputExpressions, Optional.empty(), Optional.empty(), groupingId, true, type, child); } /** @@ -96,7 +100,7 @@ private LogicalRepeat( */ private LogicalRepeat(List> groupingSets, List outputExpressions, Optional groupExpression, Optional logicalProperties, - Optional groupingId, boolean withInProjection, CHILD_TYPE child) { + Optional groupingId, boolean withInProjection, RepeatType type, CHILD_TYPE child) { super(PlanType.LOGICAL_REPEAT, groupExpression, logicalProperties, child); this.groupingSets = Objects.requireNonNull(groupingSets, "groupingSets can not be null") .stream() @@ -106,6 +110,7 @@ private LogicalRepeat(List> groupingSets, List Objects.requireNonNull(outputExpressions, "outputExpressions can not be null")); this.groupingId = groupingId; this.withInProjection = withInProjection; + this.type = type; } @Override @@ -122,6 +127,10 @@ public Optional getGroupingId() { return groupingId; } + public RepeatType getRepeatType() { + return type; + } + @Override public List getOutputs() { return outputExpressions; @@ -217,13 +226,13 @@ public int hashCode() { @Override public LogicalRepeat withChildren(List children) { Preconditions.checkArgument(children.size() == 1); - return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, children.get(0)); + return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, type, children.get(0)); } @Override public LogicalRepeat withGroupExpression(Optional groupExpression) { return new LogicalRepeat<>(groupingSets, outputExpressions, groupExpression, - Optional.of(getLogicalProperties()), groupingId, withInProjection, child()); + Optional.of(getLogicalProperties()), groupingId, withInProjection, type, child()); } @Override @@ -231,35 +240,35 @@ public Plan withGroupExprLogicalPropChildren(Optional groupExpr Optional logicalProperties, List children) { Preconditions.checkArgument(children.size() == 1); return new LogicalRepeat<>(groupingSets, outputExpressions, groupExpression, logicalProperties, - groupingId, withInProjection, children.get(0)); + groupingId, withInProjection, type, children.get(0)); } public LogicalRepeat withGroupSets(List> groupingSets) { - return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, child()); + return new LogicalRepeat<>(groupingSets, outputExpressions, groupingId, type, child()); } public LogicalRepeat withGroupSetsAndOutput(List> groupingSets, List outputExpressionList) { - return new LogicalRepeat<>(groupingSets, outputExpressionList, groupingId, child()); + return new LogicalRepeat<>(groupingSets, outputExpressionList, groupingId, type, child()); } @Override public LogicalRepeat withAggOutput(List newOutput) { - return new LogicalRepeat<>(groupingSets, newOutput, groupingId, child()); + return new LogicalRepeat<>(groupingSets, newOutput, groupingId, type, child()); } public LogicalRepeat withNormalizedExpr(List> groupingSets, List outputExpressionList, SlotReference groupingId, Plan child) { - return new LogicalRepeat<>(groupingSets, outputExpressionList, groupingId, child); + return new LogicalRepeat<>(groupingSets, outputExpressionList, groupingId, type, child); } public LogicalRepeat withAggOutputAndChild(List newOutput, Plan child) { - return new LogicalRepeat<>(groupingSets, newOutput, groupingId, child); + return new LogicalRepeat<>(groupingSets, newOutput, groupingId, type, child); } public LogicalRepeat withInProjection(boolean withInProjection) { return new LogicalRepeat<>(groupingSets, outputExpressions, - Optional.empty(), Optional.empty(), groupingId, withInProjection, child()); + Optional.empty(), Optional.empty(), groupingId, withInProjection, type, child()); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index d4bb2b3fa4658d..562fa431a3a036 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -830,6 +830,9 @@ public class SessionVariable implements Serializable, Writable { public static final String SKEW_REWRITE_JOIN_SALT_EXPLODE_FACTOR = "skew_rewrite_join_salt_explode_factor"; public static final String SKEW_REWRITE_AGG_BUCKET_NUM = "skew_rewrite_agg_bucket_num"; + public static final String DECOMPOSE_REPEAT_THRESHOLD = "decompose_repeat_threshold"; + public static final String DECOMPOSE_REPEAT_SHUFFLE_INDEX_IN_MAX_GROUP + = "decompose_repeat_shuffle_index_in_max_group"; public static final String HOT_VALUE_COLLECT_COUNT = "hot_value_collect_count"; @VariableMgr.VarAttr(name = HOT_VALUE_COLLECT_COUNT, needForward = true, @@ -3290,6 +3293,11 @@ public boolean isEnableESParallelScroll() { ) public boolean useV3StorageFormat = false; + @VariableMgr.VarAttr(name = DECOMPOSE_REPEAT_THRESHOLD) + public int decomposeRepeatThreshold = 3; + @VariableMgr.VarAttr(name = DECOMPOSE_REPEAT_SHUFFLE_INDEX_IN_MAX_GROUP) + public int decomposeRepeatShuffleIndexInMaxGroup = -1; + public static final String IGNORE_ICEBERG_DANGLING_DELETE = "ignore_iceberg_dangling_delete"; @VariableMgr.VarAttr(name = IGNORE_ICEBERG_DANGLING_DELETE, description = {"是否忽略 Iceberg 表中 dangling delete 文件对 COUNT(*) 统计信息的影响。" @@ -3303,6 +3311,7 @@ public boolean isEnableESParallelScroll() { + "to exclude the impact of dangling delete files."}) public boolean ignoreIcebergDanglingDelete = false; + // If this fe is in fuzzy mode, then will use initFuzzyModeVariables to generate some variables, // not the default value set in the code. @SuppressWarnings("checkstyle:Indentation") diff --git a/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java b/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java index 014c9ace70f887..6486a971da21f7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java +++ b/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java @@ -1286,6 +1286,8 @@ public static boolean canCollect() { * Get the map of column literal value and its row count percentage in the table. * The stringValues is like: * value1 :percent1 ;value2 :percent2 ;value3 :percent3 + * Result is ordered by percentage descending so that the hottest value is first. + * (GROUP_CONCAT in SQL does not guarantee order, so we sort here.) * @return Map of LiteralExpr -> percentage. */ public static LinkedHashMap getHotValues(String stringValues, Type type, double avgOccurrences) { @@ -1314,7 +1316,15 @@ public static LinkedHashMap getHotValues(String stringValues, Ty } } } + // Sort by percentage descending so that the hottest value is first. + // GROUP_CONCAT in SQL does not preserve subquery ORDER BY, so order is not guaranteed from DB. if (!ret.isEmpty()) { + List> entries = new ArrayList<>(ret.entrySet()); + entries.sort((a, b) -> Float.compare(b.getValue(), a.getValue())); + ret = Maps.newLinkedHashMap(); + for (Map.Entry e : entries) { + ret.put(e.getKey(), e.getValue()); + } return ret; } } catch (Exception e) { @@ -1322,4 +1332,23 @@ public static LinkedHashMap getHotValues(String stringValues, Ty } return null; } + + public static boolean isBalanced(ColumnStatistic columnStatistic, double rowCount, int instanceNum) { + double ndv = columnStatistic.ndv; + double maxHotValueCntIncludeNull; + Map hotValues = columnStatistic.getHotValues(); + // When hotValues not exist, or exist but unknown, treat nulls as the only hot value. + if (columnStatistic.getHotValues() == null || hotValues.isEmpty()) { + maxHotValueCntIncludeNull = columnStatistic.numNulls; + } else { + double rate = hotValues.values().iterator().next(); + maxHotValueCntIncludeNull = rate * rowCount > columnStatistic.numNulls + ? rate * rowCount : columnStatistic.numNulls; + } + double rowsPerInstance = (rowCount - maxHotValueCntIncludeNull) / instanceNum; + double balanceFactor = maxHotValueCntIncludeNull == 0 + ? Double.MAX_VALUE : rowsPerInstance / maxHotValueCntIncludeNull; + // The larger this factor is, the more balanced the data. + return balanceFactor > 2.0 && ndv > instanceNum * 3; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java index 556f5279412121..e2c874aad17731 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; import org.apache.doris.nereids.util.MemoPatternMatchSupported; @@ -44,6 +45,7 @@ public void testKeepNullableAfterNormalizeRepeat() { Plan plan = new LogicalRepeat<>( ImmutableList.of(ImmutableList.of(id), ImmutableList.of(name)), ImmutableList.of(idNotNull, alias), + RepeatType.GROUPING_SETS, scan1 ); PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) @@ -62,6 +64,7 @@ public void testEliminateRepeat() { Plan plan = new LogicalRepeat<>( ImmutableList.of(ImmutableList.of(id)), ImmutableList.of(idNotNull, alias), + RepeatType.GROUPING_SETS, scan1 ); PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) @@ -80,6 +83,7 @@ public void testNoEliminateRepeat() { Plan plan = new LogicalRepeat<>( ImmutableList.of(ImmutableList.of(id)), ImmutableList.of(idNotNull, alias), + RepeatType.GROUPING_SETS, scan1 ); PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java index c78394ce9ef141..af65a0f921b56a 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java @@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; @@ -39,6 +40,10 @@ import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.qe.ConnectContext; +import org.apache.doris.statistics.ColumnStatistic; +import org.apache.doris.statistics.ColumnStatisticBuilder; +import org.apache.doris.statistics.Statistics; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; @@ -50,6 +55,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; /** @@ -272,7 +278,7 @@ public void testGetNeedAddNullExpressions() throws Exception { @Test public void testCanOptimize() throws Exception { - Method method = rule.getClass().getDeclaredMethod("canOptimize", LogicalAggregate.class); + Method method = rule.getClass().getDeclaredMethod("canOptimize", LogicalAggregate.class, ConnectContext.class); method.setAccessible(true); SlotReference a = new SlotReference("a", IntegerType.INSTANCE); @@ -303,7 +309,7 @@ public void testCanOptimize() throws Exception { ImmutableList.of(a, b, c, d, sumAlias), repeat); - int result = (int) method.invoke(rule, aggregate); + int result = (int) method.invoke(rule, aggregate, connectContext); Assertions.assertEquals(0, result); // Test case 2: Child is not LogicalRepeat @@ -311,7 +317,7 @@ public void testCanOptimize() throws Exception { ImmutableList.of(a), ImmutableList.of(a, sumAlias), emptyRelation); - result = (int) method.invoke(rule, aggregateWithNonRepeat); + result = (int) method.invoke(rule, aggregateWithNonRepeat, connectContext); Assertions.assertEquals(-1, result); // Test case 3: Unsupported aggregate function (Avg) @@ -322,7 +328,7 @@ public void testCanOptimize() throws Exception { ImmutableList.of(a, b, c, d), ImmutableList.of(a, b, c, d, avgAlias), repeat); - result = (int) method.invoke(rule, aggregateWithCount); + result = (int) method.invoke(rule, aggregateWithCount, connectContext); Assertions.assertEquals(-1, result); // Test case 4: Grouping sets size <= 3 @@ -340,7 +346,7 @@ public void testCanOptimize() throws Exception { ImmutableList.of(a, b), ImmutableList.of(a, b, sumAlias), smallRepeat); - result = (int) method.invoke(rule, aggregateWithSmallRepeat); + result = (int) method.invoke(rule, aggregateWithSmallRepeat, connectContext); Assertions.assertEquals(-1, result); } @@ -369,6 +375,7 @@ public void testConstructUnion() throws Exception { groupingSets, (List) ImmutableList.of(a, b), new SlotReference("grouping_id", IntegerType.INSTANCE), + RepeatType.GROUPING_SETS, emptyRelation); LogicalAggregate aggregate = new LogicalAggregate<>( ImmutableList.of(a, b), @@ -391,7 +398,7 @@ public void testConstructUnion() throws Exception { @Test public void testConstructProducer() throws Exception { Method method = rule.getClass().getDeclaredMethod("constructProducer", - LogicalAggregate.class, int.class, DistinctSelectorContext.class, Map.class); + LogicalAggregate.class, int.class, DistinctSelectorContext.class, Map.class, ConnectContext.class); method.setAccessible(true); SlotReference a = new SlotReference("a", IntegerType.INSTANCE); @@ -412,7 +419,7 @@ public void testConstructProducer() throws Exception { LogicalRepeat repeat = new LogicalRepeat<>( groupingSets, (List) ImmutableList.of(a, b, c, d), - null, + RepeatType.GROUPING_SETS, emptyRelation); Sum sumFunc = new Sum(d); Alias sumAlias = new Alias(sumFunc, "sum_d"); @@ -423,7 +430,7 @@ public void testConstructProducer() throws Exception { Map preToCloneSlotMap = new HashMap<>(); LogicalCTEProducer> result = (LogicalCTEProducer>) - method.invoke(rule, aggregate, 0, ctx, preToCloneSlotMap); + method.invoke(rule, aggregate, 0, ctx, preToCloneSlotMap, connectContext); Assertions.assertNotNull(result); Assertions.assertNotNull(result.child()); @@ -459,6 +466,7 @@ public void testConstructRepeat() throws Exception { originalGroupingSets, (List) ImmutableList.of(a, b, c), new SlotReference("grouping_id", IntegerType.INSTANCE), + RepeatType.GROUPING_SETS, emptyRelation); List> newGroupingSets = ImmutableList.of( @@ -482,4 +490,154 @@ public void testConstructRepeat() throws Exception { Assertions.assertEquals(2, result.getGroupingSets().size()); Assertions.assertTrue(groupingFunctionSlots.isEmpty()); } + + @Test + public void testChoosePreAggShuffleKeyPartitionExprs() throws Exception { + Method method = rule.getClass().getDeclaredMethod("choosePreAggShuffleKeyPartitionExprs", + LogicalRepeat.class, int.class, List.class, org.apache.doris.qe.ConnectContext.class); + method.setAccessible(true); + + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE); + + List maxGroupByList = ImmutableList.of(a, b, c); + LogicalEmptyRelation emptyRelation = new LogicalEmptyRelation( + org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator.newRelationId(), + ImmutableList.of()); + List> groupingSets = ImmutableList.of( + ImmutableList.of(a, b, c), + ImmutableList.of(a, b), + ImmutableList.of(a) + ); + LogicalRepeat repeatRollup = new LogicalRepeat<>( + groupingSets, + (List) ImmutableList.of(a, b, c), + null, + RepeatType.ROLLUP, + emptyRelation); + LogicalRepeat repeatGroupingSets = new LogicalRepeat<>( + groupingSets, + (List) ImmutableList.of(a, b, c), + new SlotReference("grouping_id", IntegerType.INSTANCE), + RepeatType.GROUPING_SETS, + emptyRelation); + LogicalRepeat repeatCube = new LogicalRepeat<>( + groupingSets, + (List) ImmutableList.of(a, b, c), + new SlotReference("grouping_id", IntegerType.INSTANCE), + RepeatType.CUBE, + emptyRelation); + + // Case 1: Session variable decomposeRepeatShuffleIndexInMaxGroup = 0, should return third expr + connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = 2; + @SuppressWarnings("unchecked") + Optional> result2 = (Optional>) method.invoke( + rule, repeatRollup, 0, maxGroupByList, connectContext); + Assertions.assertTrue(result2.isPresent()); + Assertions.assertEquals(1, result2.get().size()); + Assertions.assertEquals(c, result2.get().get(0)); + + // Case 2: Session variable = -1 (default), fall through to repeat-type logic (may be empty if no stats) + connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1; + @SuppressWarnings("unchecked") + Optional> resultDefault = (Optional>) method.invoke( + rule, repeatRollup, 0, maxGroupByList, connectContext); + // With no column stats, chooseByRollupPrefixThenNdv typically returns empty + Assertions.assertEquals(resultDefault, Optional.empty()); + + // Case 3: Session variable out of range (>= size), should not use index, fall through + connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = 10; + @SuppressWarnings("unchecked") + Optional> resultOutOfRange = (Optional>) method.invoke( + rule, repeatRollup, 0, maxGroupByList, connectContext); + Assertions.assertEquals(resultOutOfRange, Optional.empty()); + + // Case 4: RepeatType GROUPING_SETS and CUBE (smoke test, result depends on stats) + connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1; + @SuppressWarnings("unchecked") + Optional> resultGs = (Optional>) method.invoke( + rule, repeatGroupingSets, 0, maxGroupByList, connectContext); + Assertions.assertEquals(resultGs, Optional.empty()); + + // Case 5: RepeatType GROUPING_SETS and CUBE (smoke test, result depends on stats) + connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1; + @SuppressWarnings("unchecked") + Optional> resultCb = (Optional>) method.invoke( + rule, repeatCube, 0, maxGroupByList, connectContext); + Assertions.assertEquals(resultCb, Optional.empty()); + + // Restore default + connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1; + } + + /** Helper: build Statistics with column ndv for given expressions. */ + private static Statistics statsWithNdv(Map exprToNdv) { + Map map = new HashMap<>(); + for (Map.Entry e : exprToNdv.entrySet()) { + ColumnStatistic col = new ColumnStatisticBuilder(1) + .setNdv(e.getValue()) + .setAvgSizeByte(4) + .setNumNulls(0) + .setMinValue(0) + .setMaxValue(100) + .setIsUnknown(false) + .setUpdatedTime("") + .build(); + map.put(e.getKey(), col); + } + return new Statistics(100, map); + } + + @Test + public void testChooseByAppearanceThenNdv() throws Exception { + Method method = rule.getClass().getDeclaredMethod("chooseByAppearanceThenNdv", + List.class, int.class, List.class, Statistics.class, int.class); + method.setAccessible(true); + + SlotReference a = new SlotReference("a", IntegerType.INSTANCE); + SlotReference b = new SlotReference("b", IntegerType.INSTANCE); + SlotReference c = new SlotReference("c", IntegerType.INSTANCE); + List candidates = ImmutableList.of(a, b, c); + + // grouping sets: index 0 = max (a,b,c), index 1 = (a,b), index 2 = (a) + // non-max: (a,b) and (a). a appears 2, b appears 1, c appears 1. + // countToCandidate: 1->[b,c], 2->[a]. TreeMap iterates 1 then 2. + // For count 1: chooseByNdv([b,c], stats, total). Need ndv > total to return. b:60, c:80, total=50 -> max ndv 80>50 -> return c. + List> groupingSets = ImmutableList.of( + ImmutableList.of(a, b, c), + ImmutableList.of(a, c), + ImmutableList.of(c) + ); + + Map exprToNdv = new HashMap<>(); + exprToNdv.put(a, 40.0); + exprToNdv.put(b, 60.0); + exprToNdv.put(c, 50.0); + Statistics stats = statsWithNdv(exprToNdv); + + @SuppressWarnings("unchecked") + Optional chosen = (Optional) method.invoke( + rule, groupingSets, -1, candidates, stats, 45); + Assertions.assertTrue(chosen.isPresent()); + Assertions.assertEquals(c, chosen.get()); + + // When no candidate has ndv > totalInstanceNum, return empty + @SuppressWarnings("unchecked") + Optional empty = (Optional) method.invoke( + rule, groupingSets, -1, candidates, stats, 1000); + Assertions.assertFalse(empty.isPresent()); + + @SuppressWarnings("unchecked") + Optional chosen2 = (Optional) method.invoke( + rule, groupingSets, -1, candidates, stats, 50); + Assertions.assertTrue(chosen2.isPresent()); + Assertions.assertEquals(b, chosen2.get()); + + // inputStats null -> chooseByNdv returns empty for every group -> empty + @SuppressWarnings("unchecked") + Optional emptyNullStats = (Optional) method.invoke( + rule, groupingSets, -1, candidates, null, 50); + Assertions.assertFalse(emptyNullStats.isPresent()); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java index 8bae1713fe1b51..44f5aa8e6bd554 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushDownFilterThroughAggregationTest.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; import org.apache.doris.nereids.trees.expressions.functions.scalar.If; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; @@ -179,7 +180,7 @@ public void pushDownPredicateGroupWithRepeatTest() { Slot name = scan.getOutput().get(2); LogicalRepeat repeatPlan = new LogicalRepeat<>( ImmutableList.of(ImmutableList.of(id, gender), ImmutableList.of(id)), - ImmutableList.of(id, gender, name), scan); + ImmutableList.of(id, gender, name), RepeatType.GROUPING_SETS, scan); NamedExpression nameMax = new Alias(new Max(name), "nameMax"); final Expression filterPredicateId = new GreaterThan(id, Literal.of(1)); @@ -206,7 +207,7 @@ public void pushDownPredicateGroupWithRepeatTest() { repeatPlan = new LogicalRepeat<>( ImmutableList.of(ImmutableList.of(id, gender), ImmutableList.of(gender)), - ImmutableList.of(id, gender, name), scan); + ImmutableList.of(id, gender, name), RepeatType.GROUPING_SETS, scan); plan = new LogicalPlanBuilder(repeatPlan) .aggGroupUsingIndexAndSourceRepeat(ImmutableList.of(0, 1), ImmutableList.of( id, gender, nameMax), Optional.of(repeatPlan)) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java index 4ba8edfe34cd35..80842f6271ba54 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopierTest.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; @@ -62,6 +63,7 @@ public void testDeepCopyAggregateWithSourceRepeat() { groupingSets, scan.getOutput().stream().map(NamedExpression.class::cast).collect(Collectors.toList()), groupingId, + RepeatType.GROUPING_SETS, scan ); List groupByExprs = repeat.getOutput().subList(0, 1).stream() diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/algebra/RepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/algebra/RepeatTest.java new file mode 100644 index 00000000000000..864fcc3e21d107 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/algebra/RepeatTest.java @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans.algebra; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingId; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.Repeat.RepeatType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Unit tests for {@link Repeat} interface default methods: + * toShapes, indexesOfOutput, getGroupingSetsIndexesInOutput, computeRepeatSlotIdList. + */ +public class RepeatTest { + + private LogicalOlapScan scan; + private Slot id; + private Slot gender; + private Slot name; + private Slot age; + + @BeforeEach + public void setUp() { + scan = new LogicalOlapScan(PlanConstructor.getNextRelationId(), PlanConstructor.student, ImmutableList.of("db")); + id = scan.getOutput().get(0); + gender = scan.getOutput().get(1); + name = scan.getOutput().get(2); + age = scan.getOutput().get(3); + } + + @Test + public void testToShapes() { + // grouping sets: (id, name), (id), () + // flatten = [id, name], shapes: [false,false], [false,true], [true,true] + List> groupingSets = ImmutableList.of( + ImmutableList.of(id, name), + ImmutableList.of(id), + ImmutableList.of() + ); + Alias alias = new Alias(new Sum(name), "sum(name)"); + Repeat repeat = new LogicalRepeat<>( + groupingSets, + ImmutableList.of(id, name, alias), + RepeatType.GROUPING_SETS, + scan + ); + + Repeat.GroupingSetShapes shapes = repeat.toShapes(); + + Assertions.assertEquals(2, shapes.flattenGroupingSetExpression.size()); + Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id)); + Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name)); + Assertions.assertEquals(3, shapes.shapes.size()); + + // (id, name) -> [false, false] + Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0)); + Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(1)); + Assertions.assertEquals(0L, shapes.shapes.get(0).computeLongValue()); + + // (id) -> [false, true] (id in set, name not) + Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0)); + Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(1)); + Assertions.assertEquals(1L, shapes.shapes.get(1).computeLongValue()); + + // () -> [true, true] + Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0)); + Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1)); + Assertions.assertEquals(3L, shapes.shapes.get(2).computeLongValue()); + } + + @Test + public void testToShapesWithGroupingFunction() { + // grouping(id) adds id to flatten if not present; single set (name) has flatten [name, id] + List> groupingSets = ImmutableList.of( + ImmutableList.of(name), ImmutableList.of(name, id), ImmutableList.of()); + Alias groupingAlias = new Alias(new GroupingId(gender, age), "grouping_id(id)"); + Repeat repeat = new LogicalRepeat<>( + groupingSets, + ImmutableList.of(name, groupingAlias), + RepeatType.GROUPING_SETS, + scan + ); + + Repeat.GroupingSetShapes shapes = repeat.toShapes(); + + // flatten = [name] from getGroupBy + [id] from grouping function arg + Assertions.assertEquals(4, shapes.flattenGroupingSetExpression.size()); + Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(name)); + Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(id)); + Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(gender)); + Assertions.assertTrue(shapes.flattenGroupingSetExpression.contains(age)); + + Assertions.assertEquals(3, shapes.shapes.size()); + // (name) -> name not erased, id,gender,age erased + Assertions.assertFalse(shapes.shapes.get(0).shouldBeErasedToNull(0)); + Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(1)); + Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(2)); + Assertions.assertTrue(shapes.shapes.get(0).shouldBeErasedToNull(3)); + // (name, id) -> name,id not erased, gender and age erased + Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(0)); + Assertions.assertFalse(shapes.shapes.get(1).shouldBeErasedToNull(1)); + Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(2)); + Assertions.assertTrue(shapes.shapes.get(1).shouldBeErasedToNull(3)); + //() -> all erased + Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(0)); + Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(1)); + Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(2)); + Assertions.assertTrue(shapes.shapes.get(2).shouldBeErasedToNull(3)); + } + + @Test + public void testIndexesOfOutput() { + List outputSlots = ImmutableList.of(id, gender, name, age); + Map indexes = Repeat.indexesOfOutput(outputSlots); + Assertions.assertEquals(4, indexes.size()); + Assertions.assertEquals(0, indexes.get(id)); + Assertions.assertEquals(1, indexes.get(gender)); + Assertions.assertEquals(2, indexes.get(name)); + Assertions.assertEquals(3, indexes.get(age)); + } + + @Test + public void testGetGroupingSetsIndexesInOutput() { + // groupingSets=((name, id), (id), (gender)), output=[id, name, gender] + // expect:((1,0),(0),(2)) + List> groupingSets = ImmutableList.of( + ImmutableList.of(name, id), + ImmutableList.of(id), + ImmutableList.of(gender) + ); + Alias groupingId = new Alias(new GroupingId(id, name)); + Repeat repeat = new LogicalRepeat<>( + groupingSets, + ImmutableList.of(id, name, gender, groupingId), + RepeatType.GROUPING_SETS, + scan + ); + List outputSlots = ImmutableList.of(id, name, gender, groupingId.toSlot()); + + List> result = repeat.getGroupingSetsIndexesInOutput(outputSlots); + + Assertions.assertEquals(3, result.size()); + // (name, id) -> indexes {1, 0} + Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(1, 0)), result.get(0)); + // (id) -> index {0} + Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(0)), result.get(1)); + // (gender) -> index {2} + Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(2)), result.get(2)); + } + + @Test + public void testComputeRepeatSlotIdList() { + // groupingSets=((name, id), (id)), output=[id, name], slotIdList=[3, 4] (id->3, name->4) + List> groupingSets = ImmutableList.of( + ImmutableList.of(name, id), + ImmutableList.of(id) + ); + Repeat repeat = new LogicalRepeat<>( + groupingSets, + ImmutableList.of(id, name), + RepeatType.GROUPING_SETS, + scan + ); + List outputSlots = ImmutableList.of(id, name); + List slotIdList = ImmutableList.of(3, 4); + + List> result = repeat.computeRepeatSlotIdList(slotIdList, outputSlots); + + Assertions.assertEquals(2, result.size()); + // (name, id) -> indexes {1,0} -> slot ids {4, 3} + Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(4, 3)), result.get(0)); + // (id) -> index {0} -> slot id {3} + Assertions.assertEquals(Sets.newLinkedHashSet(ImmutableList.of(3)), result.get(1)); + } +} diff --git a/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out b/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out index e6516a0d47c1cd..f8ab9595435c91 100644 --- a/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out +++ b/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out @@ -37,7 +37,6 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) 100000 100000 100000 -100000 -- !sql_2_shape -- PhysicalCteAnchor ( cteId=CTEId#0 ) @@ -60,11 +59,9 @@ PhysicalCteAnchor ( cteId=CTEId#0 ) -- !sql_2_result -- \N ALL 1 6 \N \N \N \N ALL 1 6 \N \N \N -2020-01-02T00:00 ALL 1 6 \N 2020-01-02T00:00 \N -2020-01-02T00:00 ALL 1 6 \N 2020-01-02T00:00 \N -2020-01-03T00:00 ALL 1 6 \N 2020-01-03T00:00 \N -2020-01-03T00:00 ALL 1 6 \N 2020-01-03T00:00 \N -2020-01-04T00:00 ALL 1 6 \N 2020-01-04T00:00 \N -2020-01-04T00:00 ALL 1 6 \N 2020-01-04T00:00 \N +2020-01-04T00:00 ALL 1 6 \N \N a +2020-01-04T00:00 ALL 1 6 \N \N a +2020-01-04T00:00 ALL 1 6 \N \N b +2020-01-04T00:00 ALL 1 6 \N \N b 2020-01-04T00:00 ALL 1 7 \N \N \N diff --git a/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out b/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out index 919738109c1efd..d5245f25084e2d 100644 --- a/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out +++ b/regression-test/data/nereids_rules_p0/decompose_repeat/decompose_repeat.out @@ -369,3 +369,187 @@ 1 3 2 \N 2 0 1 3 2 2 2 0 +-- !grouping_only_in_max -- +\N \N \N 1 +1 \N \N 1 +1 2 \N 1 +1 2 1 0 +1 2 3 0 +1 3 \N 1 +1 3 2 0 + +-- !grouping_id_only_in_max_c_d -- +\N \N \N 15 +1 \N \N 7 +1 2 \N 3 +1 2 1 0 +1 2 3 0 +1 2 3 0 +1 3 \N 3 +1 3 2 0 + +-- !grouping_id_only_in_max_d -- +\N \N \N 15 +1 \N \N 7 +1 2 1 0 +1 2 1 1 +1 2 3 0 +1 2 3 0 +1 2 3 1 +1 3 2 0 +1 3 2 1 + +-- !multi_grouping_func -- +\N \N \N \N 7 7 7 3 +1 \N \N \N 3 6 5 0 +1 2 1 \N 0 0 0 0 +1 2 1 1 0 0 0 0 +1 2 3 \N 0 0 0 0 +1 2 3 3 0 0 0 0 +1 2 3 4 0 0 0 0 +1 3 2 \N 0 0 0 0 +1 3 2 2 0 0 0 0 + +-- !grouping_partial_only_in_max -- +\N \N \N \N 7 +1 2 \N \N 3 +1 2 1 \N 1 +1 2 1 1 0 +1 2 3 \N 1 +1 2 3 3 0 +1 2 3 4 0 +1 3 \N \N 3 +1 3 2 \N 1 +1 3 2 2 0 + +-- !mixed_grouping_func_1 -- +\N \N \N \N 1 7 +1 \N \N \N 0 7 +1 2 1 \N 0 1 +1 2 1 1 0 0 +1 2 3 \N 0 1 +1 2 3 3 0 0 +1 2 3 4 0 0 +1 3 2 \N 0 1 +1 3 2 2 0 0 + +-- !grouping_all_in_other -- +\N \N \N \N 3 +1 \N \N \N 1 +1 2 \N \N 0 +1 2 1 \N 0 +1 2 1 1 0 +1 2 3 \N 0 +1 2 3 3 0 +1 2 3 4 0 +1 3 \N \N 0 +1 3 2 \N 0 +1 3 2 2 0 + +-- !grouping_dup_col -- +\N \N \N \N 31 +1 \N \N \N 10 +1 2 1 \N 0 +1 2 1 1 0 +1 2 3 \N 0 +1 2 3 3 0 +1 2 3 4 0 +1 3 2 \N 0 +1 3 2 2 0 + +-- !mixed_grouping_both -- +\N \N \N \N 1 1 7 3 +1 \N \N \N 0 1 3 3 +1 2 1 \N 0 0 0 1 +1 2 1 1 0 0 0 0 +1 2 3 \N 0 0 0 1 +1 2 3 3 0 0 0 0 +1 2 3 4 0 0 0 0 +1 3 2 \N 0 0 0 1 +1 3 2 2 0 0 0 0 + +-- !grouping_different_pos -- +\N \N \N \N 3 +1 \N 1 \N 3 +1 \N 2 \N 3 +1 \N 3 \N 3 +1 2 \N \N 1 +1 2 1 1 0 +1 2 3 3 0 +1 2 3 4 0 +1 3 \N \N 1 +1 3 2 2 0 + +-- !grouping_nested_case -- +\N \N \N \N 0 +1 \N \N \N 0 +1 2 1 \N 0 +1 2 1 1 1 +1 2 3 \N 0 +1 2 3 3 1 +1 2 3 4 1 +1 3 2 \N 0 +1 3 2 2 1 + +-- !grouping_mixed_params_1 -- +\N \N \N \N 7 +1 \N \N \N 3 +1 2 \N \N 1 +1 2 1 \N 1 +1 2 1 1 0 +1 2 3 \N 1 +1 2 3 3 0 +1 2 3 4 0 +1 3 \N \N 1 +1 3 2 \N 1 +1 3 2 2 0 + +-- !grouping_single_param_multi -- +\N \N \N \N 1 +1 \N 1 \N 0 +1 \N 2 \N 0 +1 \N 3 \N 0 +1 2 1 \N 0 +1 2 1 1 0 +1 2 3 \N 0 +1 2 3 3 0 +1 2 3 4 0 +1 3 2 \N 0 +1 3 2 2 0 + +-- !grouping_multi_combinations -- +\N \N \N \N 1 3 7 15 +1 \N \N \N 0 1 3 7 +1 2 \N \N 0 0 1 3 +1 2 1 \N 0 0 0 1 +1 2 1 1 0 0 0 0 +1 2 3 \N 0 0 0 1 +1 2 3 3 0 0 0 0 +1 2 3 4 0 0 0 0 +1 3 \N \N 0 0 1 3 +1 3 2 \N 0 0 0 1 +1 3 2 2 0 0 0 0 + +-- !grouping_max_not_first -- +\N \N \N \N 3 +1 2 \N \N 3 +1 2 1 \N 1 +1 2 1 1 0 +1 2 3 \N 1 +1 2 3 3 0 +1 2 3 4 0 +1 3 \N \N 3 +1 3 2 \N 1 +1 3 2 2 0 + +-- !grouping_with_agg -- +\N \N \N \N 10 7 +1 \N \N \N 10 3 +1 2 1 \N 1 0 +1 2 1 1 1 0 +1 2 3 \N 7 0 +1 2 3 3 3 0 +1 2 3 4 4 0 +1 3 2 \N 2 0 +1 3 2 2 2 0 + diff --git a/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy b/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy index 338517afbc4f46..ea3e59dede7a29 100644 --- a/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy +++ b/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy @@ -23,11 +23,12 @@ suite("decompose_repeat") { order_qt_sum "select a,b,c,sum(d) from t1 group by rollup(a,b,c);" order_qt_agg_func_gby_key_same_col "select a,b,c,d,sum(d) from t1 group by rollup(a,b,c,d);" order_qt_multi_agg_func "select a,b,c,sum(d),sum(c),max(a) from t1 group by rollup(a,b,c,d);" - order_qt_nest_rewrite """ - select a,b,c,c1 from ( - select a,b,c,d,sum(d) c1 from t1 group by grouping sets((a,b,c),(a,b,c,d),(a),(a,b,c,c)) - ) t group by rollup(a,b,c,c1); - """ + // maybe this problem:DORIS-24075 +// order_qt_nest_rewrite """ +// select a,b,c,c1 from ( +// select a,b,c,d,sum(d) c1 from t1 group by grouping sets((a,b,c),(a,b,c,d),(a),(a,b,c,c)) +// ) t group by rollup(a,b,c,c1); +// """ order_qt_upper_ref """ select c1+10,a,b,c from (select a,b,c,sum(d) c1 from t1 group by rollup(a,b,c)) t group by c1+10,a,b,c; """ @@ -71,4 +72,49 @@ suite("decompose_repeat") { order_qt_cube "select a,b,c,d,sum(d),grouping_id(a) from t1 group by cube(a,b,c,d)" order_qt_cube_add "select a,b,c,d,sum(d)+100+grouping_id(a) from t1 group by cube(a,b,c,d);" order_qt_cube_sum_parm_add "select a,b,c,d,sum(a+1),grouping_id(a) from t1 group by cube(a,b,c,d);" + + // grouping scalar functions add more test + order_qt_grouping_only_in_max "select a,b,c, grouping(c) from t1 group by grouping sets((a,b,c),(a,b),(a),());" + order_qt_grouping_id_only_in_max_c_d "select a,b,c, grouping_id(a,b,c,d) from t1 group by grouping sets((a,b,c,d),(a,b),(a),());" + order_qt_grouping_id_only_in_max_d "select a,b,c, grouping_id(a,b,c,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + order_qt_multi_grouping_func "select a,b,c,d, grouping_id(a,b,c), grouping_id(c,b,a), grouping_id(c,a,b), grouping_id(a,a) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + + // more test cases for grouping scalar function bug(added by ai) + // Test case: grouping function with partial parameters only in max group + order_qt_grouping_partial_only_in_max "select a,b,c,d, grouping_id(a,c,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a,b),());" + // Test case: multiple grouping functions, some can optimize and some cannot + order_qt_mixed_grouping_func_1 "select a,b,c,d, grouping(a), grouping_id(b,c,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + // Test case: grouping function with all parameters exist in other groups (should optimize) + order_qt_grouping_all_in_other "select a,b,c,d, grouping_id(a,b) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a,b),(a),());" + // Test case: grouping function with same column repeated + order_qt_grouping_dup_col "select a,b,c,d, grouping_id(a,b,a,c,a) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + // Test case: both grouping and grouping_id with different parameters + order_qt_mixed_grouping_both "select a,b,c,d, grouping(a), grouping(b), grouping_id(a,b,c), grouping_id(c,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + // Test case: grouping function with columns in different positions + order_qt_grouping_different_pos "select a,b,c,d, grouping_id(b,d) from t1 group by grouping sets((a,b,c,d),(a,b),(a,c),());" + // Test case: nested case with grouping functions that reference only-max columns + order_qt_grouping_nested_case "select a,b,c,d, case when grouping(d) = 1 then 0 else 1 end from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + // Test case: grouping function parameter mix - one in max only, others in all groups + order_qt_grouping_mixed_params_1 "select a,b,c,d, grouping_id(a,b,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a,b),(a),());" + // Test case: grouping function with single parameter that exists in multiple groups + order_qt_grouping_single_param_multi "select a,b,c,d, grouping(c) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a,c),());" + // Test case: multiple grouping_id functions with different parameter combinations + order_qt_grouping_multi_combinations "select a,b,c,d, grouping_id(a), grouping_id(a,b), grouping_id(a,b,c), grouping_id(a,b,c,d) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a,b),(a),());" + // Test case: grouping function where max group is not first + order_qt_grouping_max_not_first "select a,b,c,d, grouping_id(c,d) from t1 group by grouping sets((a,b),(a,b,c),(a,b,c,d),());" + // Test case: complex case with aggregation function and grouping function + order_qt_grouping_with_agg "select a,b,c,d, sum(d), grouping_id(a,b,c) from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());" + + // test empty grouping set + sql "select 2 from t_repeat_pick_shuffle_key group by grouping sets((),(),(),());" + multi_sql """drop table if exists t_repeat_pick_shuffle_key; + create table t_repeat_pick_shuffle_key(a int, b int, c int, d int); + alter table t_repeat_pick_shuffle_key modify column a set stats ('row_count'='300000', 'ndv'='10', 'num_nulls'='0', 'min_value'='1', 'max_value'='300000', 'data_size'='2400000'); + alter table t_repeat_pick_shuffle_key modify column b set stats ('row_count'='300000', 'ndv'='100', 'num_nulls'='0', 'min_value'='1', 'max_value'='300000', 'data_size'='2400000'); + alter table t_repeat_pick_shuffle_key modify column c set stats ('row_count'='300000', 'ndv'='1000', 'num_nulls'='0', 'min_value'='1', 'max_value'='300000', 'data_size'='2400000'); + alter table t_repeat_pick_shuffle_key modify column d set stats ('row_count'='300000', 'ndv'='10000', 'num_nulls'='0', 'min_value'='1', 'max_value'='300000', 'data_size'='2400000');""" + sql "select a,b,c,d from t_repeat_pick_shuffle_key group by rollup(a,b,c,d);" + sql "select a,b,c,d from t_repeat_pick_shuffle_key group by cube(a,b,c,d);" + sql "select a,b,c,d from t_repeat_pick_shuffle_key group by grouping sets((a,b,c,d),(b,c,d),(c),(c,a));" + } \ No newline at end of file