Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinction;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
Expand Down Expand Up @@ -270,15 +269,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi
normalizedAggFuncsToSlotContext.pushDownToNamedExpression(normalizedAggFuncs)
);
// create new agg node
ImmutableList<NamedExpression> aggOutput = normalizedAggOutputBuilder.build();
ImmutableList.Builder<NamedExpression> newAggOutputBuilder
= ImmutableList.builderWithExpectedSize(aggOutput.size());
for (NamedExpression output : aggOutput) {
Expression rewrittenExpr = output.rewriteDownShortCircuit(
e -> e instanceof MultiDistinction ? ((MultiDistinction) e).withMustUseMultiDistinctAgg(true) : e);
newAggOutputBuilder.add((NamedExpression) rewrittenExpr);
}
ImmutableList<NamedExpression> normalizedAggOutput = newAggOutputBuilder.build();
ImmutableList<NamedExpression> normalizedAggOutput = normalizedAggOutputBuilder.build();

// create upper projects by normalize all output exprs in old LogicalAggregate
// In aggregateOutput, the expressions inside the agg function can be rewritten
Expand Down Expand Up @@ -313,7 +304,7 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi
for (Slot slot : missingSlotsInAggregate) {
Alias anyValue = new Alias(new AnyValue(slot), slot.getName());
replaceMap.put(slot, anyValue.toSlot());
newAggOutputBuilder.add(anyValue);
normalizedAggOutputBuilder.add(anyValue);
}
upperProjects = upperProjects.stream()
.map(e -> (NamedExpression) ExpressionUtils.replace(e, replaceMap))
Expand All @@ -328,10 +319,10 @@ private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<Logi
} else {
bottomPlan = aggregate.child();
}
// NOTICE: we must call newAggOutputBuilder.build() here, newAggOutputBuilder could be updated if we need
// to process non-standard aggregate: SELECT c1, c2 FROM t GROUP BY c1
// NOTICE: we must call normalizedAggOutputBuilder.build() again here instead of reusing normalizedAggOutput,
// because the builder may have gained extra AnyValue outputs while handling a non-standard aggregate
// such as: SELECT c1, c2 FROM t GROUP BY c1
LogicalAggregate<?> newAggregate =
aggregate.withNormalized(normalizedGroupExprs, newAggOutputBuilder.build(), bottomPlan);
aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutputBuilder.build(), bottomPlan);
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx);
LogicalProject<Plan> project = eliminateGroupByConstant(groupByExprContext, rewriteContext,
normalizedGroupExprs, normalizedAggOutput, bottomProjects, aggregate, upperProjects, newAggregate);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public List<Rule> buildRules() {
.toRule(RuleType.DISTINCT_AGGREGATE_SPLIT),
logicalAggregate()
.when(agg -> agg.getGroupByExpressions().isEmpty()
&& agg.mustUseMultiDistinctAgg())
&& agg.mustUseMultiDistinctAgg() && !AggregateUtils.containsCountDistinctMultiExpr(agg))
.then(this::convertToMultiDistinct)
.toRule(RuleType.PROCESS_SCALAR_AGG_MUST_USE_MULTI_DISTINCT)
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
Expand All @@ -32,6 +33,7 @@
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.LinkedHashSet;
import java.util.List;

/** MultiDistinctCount */
Expand All @@ -40,7 +42,6 @@ public class MultiDistinctCount extends NotNullableAggregateFunction
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(AnyDataType.INSTANCE_WITHOUT_INDEX)
);
private final boolean mustUseMultiDistinctAgg;

// MultiDistinctCount is created in AggregateStrategies phase
// can't change getSignatures to use type coercion rule to add a cast expr
Expand All @@ -50,28 +51,29 @@ public MultiDistinctCount(Expression arg0, Expression... varArgs) {
}

public MultiDistinctCount(boolean distinct, Expression arg0, Expression... varArgs) {
this(false, false, ExpressionUtils.mergeArguments(arg0, varArgs));
this(false, ExpressionUtils.mergeArguments(arg0, varArgs));
}

private MultiDistinctCount(boolean mustUseMultiDistinctAgg, boolean distinct, List<Expression> children) {
super("multi_distinct_count", false, children
private MultiDistinctCount(boolean distinct, List<Expression> children) {
super("multi_distinct_count", false, new LinkedHashSet<>(children)
.stream()
.map(arg -> !(arg instanceof Unbound) && arg.getDataType() instanceof DateLikeType
? new Cast(arg, BigIntType.INSTANCE) : arg)
.collect(ImmutableList.toImmutableList()));
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
if (super.children().size() > 1) {
throw new AnalysisException("MultiDistinctCount's children size must be 1");
}
}

/** constructor for withChildren and reuse signature */
protected MultiDistinctCount(boolean mustUseMultiDistinctAgg, AggregateFunctionParams functionParams) {
protected MultiDistinctCount(AggregateFunctionParams functionParams) {
super(functionParams);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}

@Override
public MultiDistinctCount withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(!children.isEmpty());
return new MultiDistinctCount(mustUseMultiDistinctAgg, getFunctionParams(false, children));
Preconditions.checkArgument(children.size() == 1, "MultiDistinctCount's children size must be 1");
return new MultiDistinctCount(getFunctionParams(false, children));
}

@Override
Expand All @@ -84,16 +86,6 @@ public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg;
}

@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctCount(mustUseMultiDistinctAgg, getFunctionParams(children));
}

@Override
public Expression resultForEmptyInput() {
return new BigIntLiteral(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ public class MultiDistinctGroupConcat extends NullableAggregateFunction
.varArgs(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT, AnyDataType.INSTANCE_WITHOUT_INDEX)
);

private final boolean mustUseMultiDistinctAgg;
private final int nonOrderArguments;

/**
Expand All @@ -64,26 +63,19 @@ public MultiDistinctGroupConcat(Expression arg, Expression... others) {
/**
* constructor with argument list.
*/
public MultiDistinctGroupConcat(boolean alwaysNullable, List<Expression> args) {
this(false, alwaysNullable, args);
}

private MultiDistinctGroupConcat(boolean alwaysNullable, Expression arg,
Expression... others) {
this(alwaysNullable, ExpressionUtils.mergeArguments(arg, others));
}

private MultiDistinctGroupConcat(boolean mustUseMultiDistinctAgg, boolean alwaysNullable, List<Expression> args) {
public MultiDistinctGroupConcat(boolean alwaysNullable, List<Expression> args) {
super("multi_distinct_group_concat", false, alwaysNullable, args);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
this.nonOrderArguments = findOrderExprIndex(children);
}

/** constructor for withChildren and reuse signature */
protected MultiDistinctGroupConcat(
boolean mustUseMultiDistinctAgg, NullableAggregateFunctionParams functionParams) {
protected MultiDistinctGroupConcat(NullableAggregateFunctionParams functionParams) {
super(functionParams);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
this.nonOrderArguments = findOrderExprIndex(children);
}

Expand All @@ -95,15 +87,15 @@ public boolean nullable() {

@Override
public MultiDistinctGroupConcat withAlwaysNullable(boolean alwaysNullable) {
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, getAlwaysNullableFunctionParams(alwaysNullable));
return new MultiDistinctGroupConcat(getAlwaysNullableFunctionParams(alwaysNullable));
}

/**
* withDistinctAndChildren.
*/
@Override
public MultiDistinctGroupConcat withDistinctAndChildren(boolean distinct, List<Expression> children) {
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, getFunctionParams(false, children));
return new MultiDistinctGroupConcat(getFunctionParams(false, children));
}

@Override
Expand All @@ -126,16 +118,6 @@ public List<FunctionSignature> getSignatures() {
}
}

@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg || children.stream().anyMatch(OrderExpression.class::isInstance);
}

@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctGroupConcat(mustUseMultiDistinctAgg, alwaysNullable, children);
}

private int findOrderExprIndex(List<Expression> children) {
Preconditions.checkArgument(children().size() >= 1, "children's size should >= 1");
boolean foundOrderExpr = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,6 @@
/** MultiDistinctSum */
public class MultiDistinctSum extends NullableAggregateFunction implements UnaryExpression,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {

private final boolean mustUseMultiDistinctAgg;

public MultiDistinctSum(Expression arg0) {
this(false, arg0);
}
Expand All @@ -45,19 +42,12 @@ public MultiDistinctSum(boolean distinct, Expression arg0) {
}

public MultiDistinctSum(boolean distinct, boolean alwaysNullable, Expression arg0) {
this(false, false, alwaysNullable, arg0);
}

private MultiDistinctSum(boolean mustUseMultiDistinctAgg, boolean distinct,
boolean alwaysNullable, Expression arg0) {
super("multi_distinct_sum", false, alwaysNullable, arg0);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}

/** constructor for withChildren and reuse signature */
private MultiDistinctSum(boolean mustUseMultiDistinctAgg, NullableAggregateFunctionParams functionParams) {
private MultiDistinctSum(NullableAggregateFunctionParams functionParams) {
super(functionParams);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}

@Override
Expand All @@ -81,27 +71,17 @@ public FunctionSignature searchSignature(List<FunctionSignature> signatures) {

@Override
public NullableAggregateFunction withAlwaysNullable(boolean alwaysNullable) {
return new MultiDistinctSum(mustUseMultiDistinctAgg, getAlwaysNullableFunctionParams(alwaysNullable));
return new MultiDistinctSum(getAlwaysNullableFunctionParams(alwaysNullable));
}

@Override
public MultiDistinctSum withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new MultiDistinctSum(mustUseMultiDistinctAgg, getFunctionParams(false, children));
return new MultiDistinctSum(getFunctionParams(false, children));
}

/**
 * Visitor entry point: forwards this node and the caller-supplied context to
 * {@code visitor.visitMultiDistinctSum} and returns whatever that call returns.
 */
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiDistinctSum(this, context);
}

@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg;
}

@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctSum(mustUseMultiDistinctAgg, getFunctionParams(children));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,17 @@
/** MultiDistinctSum0 */
public class MultiDistinctSum0 extends NotNullableAggregateFunction implements UnaryExpression,
ExplicitlyCastableSignature, ComputePrecisionForSum, MultiDistinction {

private final boolean mustUseMultiDistinctAgg;

public MultiDistinctSum0(Expression arg0) {
this(false, arg0);
}

public MultiDistinctSum0(boolean distinct, Expression arg0) {
this(false, false, arg0);
}

private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, boolean distinct, Expression arg0) {
super("multi_distinct_sum0", false, arg0);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}

/** constructor for withChildren and reuse signature */
private MultiDistinctSum0(boolean mustUseMultiDistinctAgg, AggregateFunctionParams functionParams) {
public MultiDistinctSum0(AggregateFunctionParams functionParams) {
super(functionParams);
this.mustUseMultiDistinctAgg = mustUseMultiDistinctAgg;
}

@Override
Expand All @@ -84,24 +75,14 @@ public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
@Override
public MultiDistinctSum0 withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new MultiDistinctSum0(mustUseMultiDistinctAgg, getFunctionParams(false, children));
return new MultiDistinctSum0(getFunctionParams(false, children));
}

/**
 * Visitor entry point: forwards this node and the caller-supplied context to
 * {@code visitor.visitMultiDistinctSum0} and returns whatever that call returns.
 */
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitMultiDistinctSum0(this, context);
}

@Override
public boolean mustUseMultiDistinctAgg() {
return mustUseMultiDistinctAgg;
}

@Override
public Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg) {
return new MultiDistinctSum0(mustUseMultiDistinctAgg, getFunctionParams(children));
}

@Override
public Expression resultForEmptyInput() {
DataType dataType = getDataType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,4 @@
* Base interface for multi-distinct aggregate functions.
*/
public interface MultiDistinction extends TreeNode<Expression> {
Expression withMustUseMultiDistinctAgg(boolean mustUseMultiDistinctAgg);
}
Original file line number Diff line number Diff line change
Expand Up @@ -897,3 +897,25 @@ PhysicalResultSink
--------hashAgg[GLOBAL]
----------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_dst_key1]

-- !multi_distinct_count_and_count_distinct_multi_expr --
2 20

-- !multi_distinct_sum_and_count_distinct_multi_expr --
1 20

-- !multi_distinct_sum0_and_count_distinct_multi_expr --
1 20

-- !multi_distinct_group_concat_and_count_distinct_multi_expr --
PhysicalResultSink
--hashAgg[DISTINCT_GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[DISTINCT_LOCAL]
--------hashAgg[GLOBAL]
----------PhysicalDistribute[DistributionSpecHash]
------------hashAgg[LOCAL]
--------------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_id]

-- !multi_distinct_count_2_same_args --
10

Original file line number Diff line number Diff line change
Expand Up @@ -139,4 +139,27 @@ suite("agg_strategy") {

qt_group_concat_distinct_key_is_varchar_and_distribute_key """explain shape plan
select group_concat(distinct dst_key1 ,' ') from t_gbykey_10_dstkey_10_1000_dst_key1;"""

// multi_distinct_* functions combined with count(distinct ...) over multiple expressions in one scalar aggregate
qt_multi_distinct_count_and_count_distinct_multi_expr """
select multi_distinct_count(dst_key1), count(distinct dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id;
"""
qt_multi_distinct_sum_and_count_distinct_multi_expr """
select multi_distinct_sum(dst_key1), count(distinct dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id;
"""
qt_multi_distinct_sum0_and_count_distinct_multi_expr """
select multi_distinct_sum0(dst_key1), count(distinct dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id;
"""
qt_multi_distinct_group_concat_and_count_distinct_multi_expr """
explain shape plan
select multi_distinct_group_concat(dst_key1 order by dst_key2), count(distinct dst_key1,dst_key2) from t_gbykey_10_dstkey_10_1000_id;
"""

// multi_distinct_count rejects more than one distinct argument
test {
sql "select multi_distinct_count(dst_key1,dst_key2) from t_gbykey_2_dstkey_10_30_id;"
exception "multi_distinct_count(dst_key1, dst_key2), MultiDistinctCount's children size must be 1"
}
// multi_distinct_count accepts repeated arguments when they are all the same expression
qt_multi_distinct_count_2_same_args "select multi_distinct_count(dst_key2,dst_key2) from t_gbykey_2_dstkey_10_30_id;"
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.

suite("agg_strategy") {
suite("load") {
sql "set global enable_auto_analyze=false"
// ndv is high
sql "drop table if exists t_gbykey_10_dstkey_10_1000_id"
Expand Down
Loading