diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java index 80fb39866437bb..fd5ae6931d96bf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java @@ -37,7 +37,6 @@ import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue; -import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.plans.Plan; @@ -270,15 +269,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional aggOutput = normalizedAggOutputBuilder.build(); - ImmutableList.Builder newAggOutputBuilder - = ImmutableList.builderWithExpectedSize(aggOutput.size()); - for (NamedExpression output : aggOutput) { - Expression rewrittenExpr = output.rewriteDownShortCircuit( - e -> e instanceof MultiDistinction ? ((MultiDistinction) e).withMustUseMultiDistinctAgg(true) : e); - newAggOutputBuilder.add((NamedExpression) rewrittenExpr); - } - ImmutableList normalizedAggOutput = newAggOutputBuilder.build(); + ImmutableList normalizedAggOutput = normalizedAggOutputBuilder.build(); // create upper projects by normalize all output exprs in old LogicalAggregate // In aggregateOutput, the expressions inside the agg function can be rewritten @@ -313,7 +304,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional (NamedExpression) ExpressionUtils.replace(e, replaceMap)) @@ -328,10 +319,10 @@ private LogicalPlan normalizeAgg(LogicalAggregate aggregate, Optional newAggregate = - aggregate.withNormalized(normalizedGroupExprs, newAggOutputBuilder.build(), bottomPlan); + aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutputBuilder.build(), bottomPlan); ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx); LogicalProject project = eliminateGroupByConstant(groupByExprContext, rewriteContext, normalizedGroupExprs, normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java index e34d56d383c6da..1f118066d8dce3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java @@ -84,7 +84,7 @@ public List buildRules() { .toRule(RuleType.DISTINCT_AGGREGATE_SPLIT), logicalAggregate() .when(agg -> agg.getGroupByExpressions().isEmpty() - && agg.mustUseMultiDistinctAgg()) + && agg.mustUseMultiDistinctAgg() && !AggregateUtils.containsCountDistinctMultiExpr(agg)) .then(this::convertToMultiDistinct) .toRule(RuleType.PROCESS_SCALAR_AGG_MUST_USE_MULTI_DISTINCT) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java index cdafbef9528fb5..1dadd80d6dabab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctCount.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.analyzer.Unbound; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; @@ -32,6 +33,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import java.util.LinkedHashSet; import java.util.List; /** MultiDistinctCount */ @@ -40,7 +42,6 @@ public class MultiDistinctCount extends NotNullableAggregateFunction public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE_WITHOUT_INDEX) ); - private final boolean mustUseMultiDistinctAgg; // MultiDistinctCount is created in AggregateStrategies phase // can't change getSignatures to use type coercion rule to add a cast expr @@ -50,28 +51,29 @@ public MultiDistinctCount(Expression arg0, Expression... varArgs) { } public MultiDistinctCount(boolean distinct, Expression arg0, Expression... varArgs) { - this(false, false, ExpressionUtils.mergeArguments(arg0, varArgs)); + this(false, ExpressionUtils.mergeArguments(arg0, varArgs)); } - private MultiDistinctCount(boolean mustUseMultiDistinctAgg, boolean distinct, List children) { - super("multi_distinct_count", false, children + private MultiDistinctCount(boolean distinct, List children) { + super("multi_distinct_count", false, new LinkedHashSet<>(children) .stream() .map(arg -> !(arg instanceof Unbound) && arg.getDataType() instanceof DateLikeType ? new Cast(arg, BigIntType.INSTANCE) : arg) .collect(ImmutableList.toImmutableList())); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; + if (super.children().size() > 1) { + throw new AnalysisException("MultiDistinctCount's children size must be 1"); + } } /** constructor for withChildren and reuse signature */ - protected MultiDistinctCount(boolean mustUseMultiDistinctAgg, AggregateFunctionParams functionParams) { + protected MultiDistinctCount(AggregateFunctionParams functionParams) { super(functionParams); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; } @Override public MultiDistinctCount withDistinctAndChildren(boolean distinct, List children) { - Preconditions.checkArgument(!children.isEmpty()); - return new MultiDistinctCount(mustUseMultiDistinctAgg, getFunctionParams(false, children)); + Preconditions.checkArgument(children.size() == 1, "MultiDistinctCount's children size must be 1"); + return new MultiDistinctCount(getFunctionParams(false, children)); } @Override @@ -84,16 +86,6 @@ public List getSignatures() { return SIGNATURES; } - @Override - public boolean mustUseMultiDistinctAgg() { - return mustUseMultiDistinctAgg; - } - - @Override - public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) { - return new MultiDistinctCount(mustUseMultiDistinctAgg, getFunctionParams(children)); - } - @Override public Expression resultForEmptyInput() { return new BigIntLiteral(0); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java index f686b7727e1add..5fc78899a72b4b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctGroupConcat.java @@ -51,7 +51,6 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction .varArgs(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE_WITHOUT_INDEX) ); - private final boolean mustUseMultiDistinctAgg; private final int nonOrderArguments; /** @@ -64,26 +63,19 @@ public MultiDistinctGroupConcat(Expression arg, Expression... others) { /** * constructor with argument list. */ - public MultiDistinctGroupConcat(boolean alwaysNullable, List args) { - this(false, alwaysNullable, args); - } - private MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg, Expression... others) { this(alwaysNullable, ExpressionUtils.mergeArguments(arg, others)); } - private MultiDistinctGroupConcat(boolean mustUseMultiDistinctAgg, boolean alwaysNullable, List args) { + public MultiDistinctGroupConcat(boolean alwaysNullable, List args) { super("multi_distinct_group_concat", false, alwaysNullable, args); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; this.nonOrderArguments = findOrderExprIndex(children); } /** constructor for withChildren and reuse signature */ - protected MultiDistinctGroupConcat( - boolean mustUseMultiDistinctAgg, NullableAggregateFunctionParams functionParams) { + protected MultiDistinctGroupConcat(NullableAggregateFunctionParams functionParams) { super(functionParams); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; this.nonOrderArguments = findOrderExprIndex(children); } @@ -95,7 +87,7 @@ public boolean nullable() { @Override public MultiDistinctGroupConcat withAlwaysNullable(boolean alwaysNullable) { - return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, getAlwaysNullableFunctionParams(alwaysNullable)); + return new MultiDistinctGroupConcat(getAlwaysNullableFunctionParams(alwaysNullable)); } /** @@ -103,7 +95,7 @@ public MultiDistinctGroupConcat withAlwaysNullable(boolean alwaysNullable) { */ @Override public MultiDistinctGroupConcat withDistinctAndChildren(boolean distinct, List children) { - return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, getFunctionParams(false, children)); + return new MultiDistinctGroupConcat(getFunctionParams(false, children)); } @Override @@ -126,16 +118,6 @@ public List getSignatures() { } } - @Override - public boolean mustUseMultiDistinctAgg() { - return mustUseMultiDistinctAgg || children.stream().anyMatch(OrderExpression.class::isInstance); - } - - @Override - public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) { - return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children); - } - private int findOrderExprIndex(List children) { Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1"); boolean foundOrderExpr = false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java index 954e95a4383df0..ff69f4182d107c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum.java @@ -33,9 +33,6 @@ /** MultiDistinctSum */ public class MultiDistinctSum extends NullableAggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction { - - private final boolean mustUseMultiDistinctAgg; - public MultiDistinctSum(Expression arg0) { this(false, arg0); } @@ -45,19 +42,12 @@ public MultiDistinctSum(boolean distinct, Expression arg0) { } public MultiDistinctSum(boolean distinct, boolean alwaysNullable, Expression arg0) { - this(false, false, alwaysNullable, arg0); - } - - private MultiDistinctSum(boolean mustUseMultiDistinctAgg, boolean distinct, - boolean alwaysNullable, Expression arg0) { super("multi_distinct_sum", false, alwaysNullable, arg0); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; } /** constructor for withChildren and reuse signature */ - private MultiDistinctSum(boolean mustUseMultiDistinctAgg, NullableAggregateFunctionParams functionParams) { + private MultiDistinctSum(NullableAggregateFunctionParams functionParams) { super(functionParams); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; } @Override @@ -81,27 +71,17 @@ public FunctionSignature searchSignature(List signatures) { @Override public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) { - return new MultiDistinctSum(mustUseMultiDistinctAgg, getAlwaysNullableFunctionParams(alwaysNullable)); + return new MultiDistinctSum(getAlwaysNullableFunctionParams(alwaysNullable)); } @Override public MultiDistinctSum withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new MultiDistinctSum(mustUseMultiDistinctAgg, getFunctionParams(false, children)); + return new MultiDistinctSum(getFunctionParams(false, children)); } @Override public R accept(ExpressionVisitor visitor, C context) { return visitor.visitMultiDistinctSum(this, context); } - - @Override - public boolean mustUseMultiDistinctAgg() { - return mustUseMultiDistinctAgg; - } - - @Override - public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) { - return new MultiDistinctSum(mustUseMultiDistinctAgg, getFunctionParams(children)); - } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum0.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum0.java index 68af11bf5320ff..380132dce8d4ed 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum0.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinctSum0.java @@ -40,26 +40,17 @@ /** MultiDistinctSum0 */ public class MultiDistinctSum0 extends NotNullableAggregateFunction implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction { - - private final boolean mustUseMultiDistinctAgg; - public MultiDistinctSum0(Expression arg0) { this(false, arg0); } public MultiDistinctSum0(boolean distinct, Expression arg0) { - this(false, false, arg0); - } - - private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, boolean distinct, Expression arg0) { super("multi_distinct_sum0", false, arg0); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; } /** constructor for withChildren and reuse signature */ - private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, AggregateFunctionParams functionParams) { + public MultiDistinctSum0(AggregateFunctionParams functionParams) { super(functionParams); - this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg; } @Override @@ -84,7 +75,7 @@ public FunctionSignature searchSignature(List signatures) { @Override public MultiDistinctSum0 withDistinctAndChildren(boolean distinct, List children) { Preconditions.checkArgument(children.size() == 1); - return new MultiDistinctSum0(mustUseMultiDistinctAgg, getFunctionParams(false, children)); + return new MultiDistinctSum0(getFunctionParams(false, children)); } @Override @@ -92,16 +83,6 @@ public R accept(ExpressionVisitor visitor, C context) { return visitor.visitMultiDistinctSum0(this, context); } - @Override - public boolean mustUseMultiDistinctAgg() { - return mustUseMultiDistinctAgg; - } - - @Override - public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) { - return new MultiDistinctSum0(mustUseMultiDistinctAgg, getFunctionParams(children)); - } - @Override public Expression resultForEmptyInput() { DataType dataType = getDataType(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java index 2709c4bcfe906e..ab8842f730112c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/MultiDistinction.java @@ -24,5 +24,4 @@ * base class of multi-distinct agg function */ public interface MultiDistinction extends TreeNode { - Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg); } diff --git a/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out b/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out index a405272b3d3ccf..ba448a756a2c6d 100644 --- a/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out +++ b/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out @@ -897,3 +897,25 @@ PhysicalResultSink --------hashAgg[GLOBAL] ----------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_dst_key1] +-- !multi_distinct_count_and_count_distinct_multi_expr -- +2 20 + +-- !multi_distinct_sum_and_count_distinct_multi_expr -- +1 20 + +-- !multi_distinct_sum0_and_count_distinct_multi_expr -- +1 20 + +-- !multi_distinct_group_concat_and_count_distinct_multi_expr -- +PhysicalResultSink +--hashAgg[DISTINCT_GLOBAL] +----PhysicalDistribute[DistributionSpecGather] +------hashAgg[DISTINCT_LOCAL] +--------hashAgg[GLOBAL] +----------PhysicalDistribute[DistributionSpecHash] +------------hashAgg[LOCAL] +--------------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_id] + +-- !multi_distinct_count_2_same_args -- +10 + diff --git a/regression-test/suites/nereids_rules_p0/agg_strategy/agg_strategy.groovy b/regression-test/suites/nereids_rules_p0/agg_strategy/agg_strategy.groovy index 2c70208e69f1a9..72ee1b92efb415 100644 --- a/regression-test/suites/nereids_rules_p0/agg_strategy/agg_strategy.groovy +++ b/regression-test/suites/nereids_rules_p0/agg_strategy/agg_strategy.groovy @@ -139,4 +139,27 @@ suite("agg_strategy") { qt_group_concat_distinct_key_is_varchar_and_distribute_key """explain shape plan select group_concat(distinct dst_key1 ,' ') from t_gbykey_10_dstkey_10_1000_dst_key1;""" + + // multi_distinct and count distinct multi expr + qt_multi_distinct_count_and_count_distinct_multi_expr """ + select multi_distinct_count(dst_key1), count(distinct dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id; + """ + qt_multi_distinct_sum_and_count_distinct_multi_expr """ + select multi_distinct_sum(dst_key1), count(distinct dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id; + """ + qt_multi_distinct_sum0_and_count_distinct_multi_expr """ + select multi_distinct_sum0(dst_key1), count(distinct dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id; + """ + qt_multi_distinct_group_concat_and_count_distinct_multi_expr """ + explain shape plan + select multi_distinct_group_concat(dst_key1 order by dst_key2), count(distinct dst_key1,dst_key2) from t_gbykey_10_dstkey_10_1000_id; + """ + + // multi_distinct_count only accept one arg + test { + sql "select multi_distinct_count(dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id;" + exception "multi_distinct_count(dst_key1, dst_key2), MultiDistinctCount's children size must be 1" + } + // multi_distinct_count only accept 2 same args + qt_multi_distinct_count_2_same_args "select multi_distinct_count(dst_key2,dst_key2) from t_gbykey_2_dstkey_10_30_id;" } \ No newline at end of file diff --git a/regression-test/suites/nereids_rules_p0/agg_strategy/load.groovy b/regression-test/suites/nereids_rules_p0/agg_strategy/load.groovy index f953cec1bfb6cd..f17f64ab5a4cd9 100644 --- a/regression-test/suites/nereids_rules_p0/agg_strategy/load.groovy +++ b/regression-test/suites/nereids_rules_p0/agg_strategy/load.groovy @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -suite("agg_strategy") { +suite("load") { sql "set global enable_auto_analyze=false" // ndv is high sql "drop table if exists t_gbykey_10_dstkey_10_1000_id"