Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,12 @@ private boolean shouldUseMultiDistinct(LogicalAggregate<? extends Plan> agg) {
}
}
} else {
if (AggregateUtils.hasUnknownStatistics(agg.getGroupByExpressions(), childStats)) {
if (agg.hasSkewHint()) {
return false;
}
if (AggregateUtils.hasUnknownStatistics(agg.getGroupByExpressions(), childStats)) {
return true;
}
// The joint NDV of the group-by keys is high, so multi_distinct is not selected.
if (aggStats.getRowCount() >= row * AggregateUtils.LOW_CARDINALITY_THRESHOLD) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.Statistics;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;

Expand Down Expand Up @@ -89,7 +90,8 @@ public List<Rule> buildRules() {
);
}

private boolean shouldUseMultiDistinct(LogicalAggregate<? extends Plan> aggregate) {
@VisibleForTesting
boolean shouldUseMultiDistinct(LogicalAggregate<? extends Plan> aggregate) {
// count(distinct a,b) cannot use multi_distinct
if (AggregateUtils.containsCountDistinctMultiExpr(aggregate)) {
return false;
Expand All @@ -111,7 +113,7 @@ private boolean shouldUseMultiDistinct(LogicalAggregate<? extends Plan> aggregat
// has unknown statistics, split to bottom and top agg
if (AggregateUtils.hasUnknownStatistics(aggregate.getGroupByExpressions(), aggChildStats)
|| AggregateUtils.hasUnknownStatistics(dstArgs, aggChildStats)) {
return false;
return true;
}

double gbyNdv = aggStats.getRowCount();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,4 +167,17 @@ default boolean canSkewRewrite() {
&& !getGroupByExpressions().isEmpty()
&& !(new HashSet<>(getGroupByExpressions()).containsAll(distinctArguments));
}

/**
 * Reports whether any aggregate function of this aggregate is marked as skewed.
 *
 * @return {@code true} if {@code isSkew()} holds for at least one function returned by
 *         {@code getAggregateFunctions()}; {@code false} otherwise, including when there are none
 */
default boolean hasSkewHint() {
    // Short-circuits on the first function that reports isSkew().
    return getAggregateFunctions().stream().anyMatch(AggregateFunction::isSkew);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,18 @@

import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctGroupConcat;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.utframe.TestWithFeService;

import mockit.Mock;
import mockit.MockUp;
import org.junit.jupiter.api.Test;

public class DistinctAggregateRewriterTest extends TestWithFeService implements MemoPatternMatchSupported {
Expand All @@ -39,8 +44,18 @@ protected void runBeforeAll() throws Exception {
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
}

/**
 * Installs a {@code mockit.MockUp} fake for {@link DistinctAggregateRewriter} whose
 * {@code shouldUseMultiDistinct} replacement returns {@code false} for every aggregate.
 *
 * <p>NOTE(review): presumably this keeps the tests that call it on the non-multi_distinct
 * (split) rewrite path regardless of statistics; confirm against DistinctAggregateRewriter.
 */
private void applyMock() {
new MockUp<DistinctAggregateRewriter>() {
@Mock
boolean shouldUseMultiDistinct(LogicalAggregate<? extends Plan> aggregate) {
// Ignore the aggregate entirely: always report that multi_distinct must not be used.
return false;
}
};
}

@Test
void testSplitSingleDistinctAgg() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select b, count(distinct a) from test.distinct_agg_split_t group by b")
.rewrite()
Expand All @@ -59,6 +74,7 @@ void testSplitSingleDistinctAgg() {

@Test
void testSplitSingleDistinctAggOtherFunctionCount() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select b, count(distinct a), count(a) from test.distinct_agg_split_t group by b")
.rewrite()
Expand All @@ -77,6 +93,7 @@ void testSplitSingleDistinctAggOtherFunctionCount() {

@Test
void testSplitSingleDistinctWithOtherAgg() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select b, count(distinct a), sum(c) from test.distinct_agg_split_t group by b")
.rewrite()
Expand All @@ -93,6 +110,7 @@ void testSplitSingleDistinctWithOtherAgg() {

@Test
void testNotSplitWhenNoGroupBy() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select count(distinct a) from test.distinct_agg_split_t")
.rewrite()
Expand All @@ -102,6 +120,7 @@ void testNotSplitWhenNoGroupBy() {

@Test
void testSplitWhenNoGroupByHasGroupConcatDistinctOrderBy() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select group_concat(distinct a, '' order by b) from test.distinct_agg_split_t")
.rewrite()
Expand All @@ -113,6 +132,7 @@ void testSplitWhenNoGroupByHasGroupConcatDistinctOrderBy() {

@Test
void testSplitWhenNoGroupByHasGroupConcatDistinct() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select group_concat(distinct a, '') from test.distinct_agg_split_t")
.rewrite()
Expand All @@ -124,6 +144,7 @@ void testSplitWhenNoGroupByHasGroupConcatDistinct() {

@Test
void testMultiExprDistinct() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select b, sum(a), count(distinct a,c) from test.distinct_agg_split_t group by b")
.rewrite()
Expand All @@ -142,6 +163,7 @@ void testMultiExprDistinct() {

@Test
void testNotSplitWhenNoDistinct() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select b, sum(a), count(c) from test.distinct_agg_split_t group by b")
.rewrite()
Expand All @@ -151,6 +173,7 @@ void testNotSplitWhenNoDistinct() {

@Test
void testSplitWithComplexExpression() {
applyMock();
PlanChecker.from(connectContext)
.analyze("select b, count(distinct a + 1) from test.distinct_agg_split_t group by b")
.rewrite()
Expand All @@ -161,4 +184,19 @@ void testSplitWithComplexExpression() {
).when(agg -> agg.getGroupByExpressions().size() == 1
&& agg.getGroupByExpressions().get(0).toSql().equals("b")));
}

/**
 * Analyzes and rewrites a group-by query with {@code count(distinct a)} and {@code sum(c)}
 * without calling {@code applyMock()}, then expects a logical aggregate that groups by the
 * single key {@code b}, has no aggregate function flagged as distinct, and contains at least
 * one {@link MultiDistinctCount}.
 */
@Test
void testMultiDistinct() {
// NOTE(review): the agg phase is set on connectContext and not restored in this method;
// confirm it cannot affect other tests that use the same context.
connectContext.getSessionVariable().setAggPhase(2);
PlanChecker.from(connectContext)
.analyze("select b, count(distinct a), sum(c) from test.distinct_agg_split_t group by b")
.rewrite()
.printlnTree()
.matches(
// Exactly one group-by key, rendered as "b".
logicalAggregate().when(agg -> agg.getGroupByExpressions().size() == 1
&& agg.getGroupByExpressions().get(0).toSql().equals("b")
// No distinct aggregate function remains after the rewrite ...
&& agg.getAggregateFunctions().stream().noneMatch(AggregateFunction::isDistinct)
// ... and a MultiDistinctCount is present among the aggregate functions.
&& agg.getAggregateFunctions().stream().anyMatch(f -> f instanceof MultiDistinctCount)
));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -199,4 +199,19 @@ void countMultiColumnsWithGby() {
);
});
}

/**
 * Explains a query with two {@code sum(distinct ...)} aggregates grouped by {@code c} and
 * asserts the shape of the optimized physical plan: result sink, distribute, project,
 * hash aggregate, distribute, hash aggregate, then any child.
 */
@Test
void multiSumWithGby() {
String sql = "select sum(distinct b), sum(distinct a) from test_distinct_multi group by c";
PlanChecker.from(connectContext).checkExplain(sql, planner -> {
Plan plan = planner.getOptimizedPlan();
// Expected nesting, outermost first; the innermost aggregate may sit on any plan.
MatchingUtils.assertMatches(plan,
physicalResultSink(
physicalDistribute(
physicalProject(
physicalHashAggregate(
physicalDistribute(
physicalHashAggregate(any())))))));
});
}
}
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !avg_shape --
PhysicalCteAnchor ( cteId=CTEId#0 )
--PhysicalCteProducer ( cteId=CTEId#0 )
----PhysicalOlapScan[test_adjust_nullable_t]
--PhysicalResultSink
----PhysicalProject[(cast(sum(DISTINCT b) as DOUBLE) / cast(count(DISTINCT b) as DOUBLE)) AS `AVG(distinct b)`, non_nullable((cast(sum(DISTINCT a) as DOUBLE) / cast(count(DISTINCT a) as DOUBLE))) AS `AVG(distinct a)`]
------hashJoin[INNER_JOIN colocated] hashCondition=((c <=> .c)) otherCondition=()
--------PhysicalProject[c AS `c`, count(DISTINCT a) AS `count(DISTINCT a)`, sum(DISTINCT a) AS `sum(DISTINCT a)`]
----------hashAgg[DISTINCT_GLOBAL]
------------hashAgg[GLOBAL]
--------------PhysicalDistribute[DistributionSpecHash]
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
--------PhysicalProject[.c, count(DISTINCT b) AS `count(DISTINCT b)`, sum(DISTINCT b) AS `sum(DISTINCT b)`]
----------hashAgg[DISTINCT_GLOBAL]
------------hashAgg[GLOBAL]
--------------PhysicalDistribute[DistributionSpecHash]
----------------PhysicalCteConsumer ( cteId=CTEId#0 )
PhysicalResultSink
--PhysicalProject[(cast(sum(DISTINCT b) as DOUBLE) / cast(count(DISTINCT b) as DOUBLE)) AS `AVG(distinct b)`, non_nullable((cast(sum(DISTINCT a) as DOUBLE) / cast(count(DISTINCT a) as DOUBLE))) AS `AVG(distinct a)`]
----hashAgg[GLOBAL]
------PhysicalDistribute[DistributionSpecHash]
--------hashAgg[LOCAL]
----------PhysicalOlapScan[test_adjust_nullable_t]

Original file line number Diff line number Diff line change
Expand Up @@ -353,8 +353,7 @@ PhysicalResultSink
-- !shape_hint_other_agg_func --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[GLOBAL]
------PhysicalOlapScan[test_skew_hint]
----PhysicalOlapScan[test_skew_hint]

-- !shape_hint_other_agg_func_expr --
PhysicalResultSink
Expand All @@ -367,9 +366,8 @@ PhysicalResultSink
-- !shape_hint_same_column_with_group_by --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalOlapScan[test_skew_hint]
----hashAgg[LOCAL]
------PhysicalOlapScan[test_skew_hint]

-- !shape_hint_same_column_with_group_by_expr --
PhysicalResultSink
Expand All @@ -391,10 +389,9 @@ PhysicalResultSink
-- !shape_hint_other_agg_func_grouping --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[GLOBAL]
------hashAgg[LOCAL]
--------PhysicalRepeat
----------PhysicalOlapScan[test_skew_hint]
----hashAgg[LOCAL]
------PhysicalRepeat
--------PhysicalOlapScan[test_skew_hint]

-- !shape_hint_other_agg_func_expr_grouping --
PhysicalResultSink
Expand Down Expand Up @@ -567,12 +564,10 @@ PhysicalResultSink
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[LOCAL]
------hashAgg[GLOBAL]
--------PhysicalOlapScan[test_skew_hint]
------PhysicalOlapScan[test_skew_hint]

-- !shape_not_rewrite --
PhysicalResultSink
--hashAgg[GLOBAL]
----hashAgg[GLOBAL]
------PhysicalOlapScan[test_skew_hint]
----PhysicalOlapScan[test_skew_hint]

Loading